refac: split into modules I
This commit is contained in:
parent
078855ba9a
commit
90b6868f5a
5 changed files with 233 additions and 199 deletions
108
context_window.py
Normal file
108
context_window.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
"""Sliding-window context-trim helpers.
|
||||
|
||||
Mirrors what llama.cpp's context-shift used to do: count tokens with tiktoken
|
||||
(cl100k_base) when available, drop oldest non-system messages until the prompt
|
||||
fits inside (n_ctx - safety_margin).
|
||||
|
||||
Also owns the per-(endpoint, model) n_ctx cache that the routes populate from
|
||||
exceed_context_size_error bodies and from finish_reason=="length" signals.
|
||||
"""
|
||||
try:
|
||||
import tiktoken as _tiktoken
|
||||
_tiktoken_enc = _tiktoken.get_encoding("cl100k_base")
|
||||
except Exception:
|
||||
_tiktoken_enc = None
|
||||
|
||||
|
||||
def _count_message_tokens(messages: list) -> int:
|
||||
"""Approximate token count for a message list.
|
||||
|
||||
Uses tiktoken cl100k_base when available (within ~5-15% of llama tokenizers).
|
||||
Falls back to char/4 heuristic if tiktoken is unavailable.
|
||||
Formula follows OpenAI's per-message overhead: 4 tokens/message + content + 2 priming.
|
||||
"""
|
||||
if _tiktoken_enc is None:
|
||||
return sum(len(str(m.get("content", ""))) for m in messages) // 4
|
||||
|
||||
total = 2 # priming tokens
|
||||
for msg in messages:
|
||||
total += 4 # per-message role/separator overhead
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
total += len(_tiktoken_enc.encode(content))
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
total += len(_tiktoken_enc.encode(part.get("text", "")))
|
||||
return total
|
||||
|
||||
|
||||
def _trim_messages_for_context(
|
||||
messages: list,
|
||||
n_ctx: int,
|
||||
safety_margin: int = None,
|
||||
target_tokens: int = None,
|
||||
) -> list:
|
||||
"""Sliding-window trim — mirrors what llama.cpp context-shift used to do.
|
||||
|
||||
Keeps all system messages and the most recent non-system messages that fit
|
||||
within (n_ctx - safety_margin) tokens. Oldest non-system messages are dropped
|
||||
first (FIFO). The last message is always preserved.
|
||||
|
||||
safety_margin defaults to 1/4 of n_ctx to leave headroom for the generated
|
||||
response, including RAG tool results and tool call JSON synthesis.
|
||||
|
||||
target_tokens: if provided, overrides the (n_ctx - safety_margin) target.
|
||||
Pass a calibrated value when actual n_prompt_tokens is known from the error
|
||||
body so that tiktoken underestimation vs the backend tokenizer is corrected.
|
||||
"""
|
||||
if target_tokens is not None:
|
||||
target = target_tokens
|
||||
else:
|
||||
if safety_margin is None:
|
||||
safety_margin = n_ctx // 4
|
||||
target = n_ctx - safety_margin
|
||||
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||||
non_system = [m for m in messages if m.get("role") != "system"]
|
||||
|
||||
while len(non_system) > 1:
|
||||
if _count_message_tokens(system_msgs + non_system) <= target:
|
||||
break
|
||||
non_system.pop(0) # drop oldest non-system message
|
||||
|
||||
# Ensure the first non-system message is a user message (chat templates require it).
|
||||
# Drop any leading assistant/tool messages that were left after trimming.
|
||||
while non_system and non_system[0].get("role") != "user":
|
||||
non_system.pop(0)
|
||||
|
||||
return system_msgs + non_system
|
||||
|
||||
|
||||
def _calibrated_trim_target(msgs: list, n_ctx: int, actual_tokens: int) -> int:
|
||||
"""Return a tiktoken-scale trim target based on how much backend tokens must be shed.
|
||||
|
||||
actual_tokens includes messages + tool schemas + overhead as counted by the backend.
|
||||
_count_message_tokens only counts message text, so we cannot derive an accurate
|
||||
per-token scale from the ratio. Instead we compute the *delta* we need to remove
|
||||
in backend space, then convert just that delta to tiktoken scale (×1.2 buffer).
|
||||
|
||||
Example: actual=17993, n_ctx=16384, headroom=4096 → need to shed 5705 backend
|
||||
tokens → shed 6846 tiktoken tokens from messages.
|
||||
"""
|
||||
cur_tiktoken = _count_message_tokens(msgs)
|
||||
headroom = n_ctx // 4 # reserve for generated output
|
||||
max_prompt = n_ctx - headroom # desired max backend tokens in prompt
|
||||
to_shed = max(0, actual_tokens - max_prompt) # backend tokens we must drop
|
||||
# Convert to tiktoken scale with 20% buffer (tiktoken underestimates llama by ~15-20%)
|
||||
tiktoken_to_shed = int(to_shed * 1.2)
|
||||
return max(1, cur_tiktoken - tiktoken_to_shed)
|
||||
|
||||
|
||||
# Per-(endpoint, model) n_ctx cache.
|
||||
# Populated from two sources:
|
||||
# 1. 400 exceed_context_size_error body → n_ctx field
|
||||
# 2. finish_reason/done_reason == "length" in streaming → prompt_tokens + completion_tokens
|
||||
# Only used for proactive pre-trimming when n_ctx <= _CTX_TRIM_SMALL_LIMIT,
|
||||
# so large-context models (200k+ for coding) are never touched.
|
||||
_endpoint_nctx: dict[tuple[str, str], int] = {}
|
||||
_CTX_TRIM_SMALL_LIMIT = 32768 # only proactively trim models with n_ctx at or below this
|
||||
35
fingerprint.py
Normal file
35
fingerprint.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
"""Conversation fingerprinting for prompt-cache-aware routing."""
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _conversation_fingerprint(model: str, messages: Optional[list],
|
||||
prompt: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
Stable hash over (model, first system + first user turn). That prefix
|
||||
determines whether the backend's prompt cache is reusable; later turns
|
||||
don't influence the routing decision because they extend the same prefix.
|
||||
Returns None when there is no usable prefix.
|
||||
"""
|
||||
parts: list[str] = [model or "_"]
|
||||
if messages:
|
||||
for m in messages:
|
||||
role = m.get("role") if isinstance(m, dict) else None
|
||||
if role not in ("system", "user"):
|
||||
continue
|
||||
content = m.get("content")
|
||||
if isinstance(content, list): # OpenAI multimodal parts
|
||||
content = "".join(
|
||||
p.get("text", "") for p in content
|
||||
if isinstance(p, dict) and p.get("type") == "text"
|
||||
)
|
||||
if not isinstance(content, str):
|
||||
continue
|
||||
parts.append(f"{role}:{content}")
|
||||
if role == "user":
|
||||
break
|
||||
elif prompt:
|
||||
parts.append(f"user:{prompt}")
|
||||
else:
|
||||
return None
|
||||
return hashlib.sha1("\x1f".join(parts).encode("utf-8", "replace")).hexdigest()
|
||||
66
images.py
Normal file
66
images.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""Image and timestamp helpers used by the Ollama/OpenAI request pipeline."""
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def iso8601_ns():
|
||||
ns = time.time_ns()
|
||||
sec, ns_rem = divmod(ns, 1_000_000_000)
|
||||
dt = datetime.fromtimestamp(sec, tz=timezone.utc)
|
||||
return (
|
||||
f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}T"
|
||||
f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}."
|
||||
f"{ns_rem:09d}Z"
|
||||
)
|
||||
|
||||
|
||||
def is_base64(image_string):
|
||||
try:
|
||||
if isinstance(image_string, str) and base64.b64encode(base64.b64decode(image_string)) == image_string.encode():
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def resize_image_if_needed(image_data):
|
||||
try:
|
||||
# Check if already data-url
|
||||
if image_data.startswith("data:"):
|
||||
try:
|
||||
header, image_data = image_data.split(",", 1)
|
||||
except ValueError:
|
||||
pass
|
||||
# Decode the base64 image data
|
||||
image_bytes = base64.b64decode(image_data)
|
||||
with Image.open(io.BytesIO(image_bytes)) as image:
|
||||
if image.mode not in ("RGB", "L"):
|
||||
image = image.convert("RGB")
|
||||
|
||||
# Get current size
|
||||
width, height = image.size
|
||||
|
||||
# Calculate the new dimensions while maintaining aspect ratio
|
||||
if width > 512 or height > 512:
|
||||
aspect_ratio = width / height
|
||||
if aspect_ratio > 1: # Width is larger
|
||||
new_width = 512
|
||||
new_height = int(512 / aspect_ratio)
|
||||
else: # Height is larger
|
||||
new_height = 512
|
||||
new_width = int(512 * aspect_ratio)
|
||||
|
||||
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# Encode the resized image back to base64
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
resized_image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
return resized_image_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {e}")
|
||||
return None
|
||||
209
router.py
209
router.py
|
|
@ -73,120 +73,14 @@ _subscribers: Set[asyncio.Queue] = set()
|
|||
_subscribers_lock = asyncio.Lock()
|
||||
token_queue: asyncio.Queue[tuple[str, str, int, int]] = asyncio.Queue()
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# Secret handling
|
||||
# -------------------------------------------------------------
|
||||
def _mask_secrets(text: str) -> str:
|
||||
"""
|
||||
Mask common API key patterns to avoid leaking secrets in logs or error payloads.
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
# OpenAI-style keys (sk-...) and generic "api key" mentions
|
||||
text = re.sub(r"sk-[A-Za-z0-9]{4}[A-Za-z0-9_-]*", "sk-***redacted***", text)
|
||||
text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text)
|
||||
return text
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Context-window sliding-window helpers
|
||||
# ------------------------------------------------------------------
|
||||
try:
|
||||
import tiktoken as _tiktoken
|
||||
_tiktoken_enc = _tiktoken.get_encoding("cl100k_base")
|
||||
except Exception:
|
||||
_tiktoken_enc = None
|
||||
|
||||
def _count_message_tokens(messages: list) -> int:
|
||||
"""Approximate token count for a message list.
|
||||
|
||||
Uses tiktoken cl100k_base when available (within ~5-15% of llama tokenizers).
|
||||
Falls back to char/4 heuristic if tiktoken is unavailable.
|
||||
Formula follows OpenAI's per-message overhead: 4 tokens/message + content + 2 priming.
|
||||
"""
|
||||
if _tiktoken_enc is None:
|
||||
return sum(len(str(m.get("content", ""))) for m in messages) // 4
|
||||
|
||||
total = 2 # priming tokens
|
||||
for msg in messages:
|
||||
total += 4 # per-message role/separator overhead
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
total += len(_tiktoken_enc.encode(content))
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
total += len(_tiktoken_enc.encode(part.get("text", "")))
|
||||
return total
|
||||
|
||||
def _trim_messages_for_context(
|
||||
messages: list,
|
||||
n_ctx: int,
|
||||
safety_margin: int = None,
|
||||
target_tokens: int = None,
|
||||
) -> list:
|
||||
"""Sliding-window trim — mirrors what llama.cpp context-shift used to do.
|
||||
|
||||
Keeps all system messages and the most recent non-system messages that fit
|
||||
within (n_ctx - safety_margin) tokens. Oldest non-system messages are dropped
|
||||
first (FIFO). The last message is always preserved.
|
||||
|
||||
safety_margin defaults to 1/4 of n_ctx to leave headroom for the generated
|
||||
response, including RAG tool results and tool call JSON synthesis.
|
||||
|
||||
target_tokens: if provided, overrides the (n_ctx - safety_margin) target.
|
||||
Pass a calibrated value when actual n_prompt_tokens is known from the error
|
||||
body so that tiktoken underestimation vs the backend tokenizer is corrected.
|
||||
"""
|
||||
if target_tokens is not None:
|
||||
target = target_tokens
|
||||
else:
|
||||
if safety_margin is None:
|
||||
safety_margin = n_ctx // 4
|
||||
target = n_ctx - safety_margin
|
||||
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||||
non_system = [m for m in messages if m.get("role") != "system"]
|
||||
|
||||
while len(non_system) > 1:
|
||||
if _count_message_tokens(system_msgs + non_system) <= target:
|
||||
break
|
||||
non_system.pop(0) # drop oldest non-system message
|
||||
|
||||
# Ensure the first non-system message is a user message (chat templates require it).
|
||||
# Drop any leading assistant/tool messages that were left after trimming.
|
||||
while non_system and non_system[0].get("role") != "user":
|
||||
non_system.pop(0)
|
||||
|
||||
return system_msgs + non_system
|
||||
|
||||
|
||||
|
||||
def _calibrated_trim_target(msgs: list, n_ctx: int, actual_tokens: int) -> int:
|
||||
"""Return a tiktoken-scale trim target based on how much backend tokens must be shed.
|
||||
|
||||
actual_tokens includes messages + tool schemas + overhead as counted by the backend.
|
||||
_count_message_tokens only counts message text, so we cannot derive an accurate
|
||||
per-token scale from the ratio. Instead we compute the *delta* we need to remove
|
||||
in backend space, then convert just that delta to tiktoken scale (×1.2 buffer).
|
||||
|
||||
Example: actual=17993, n_ctx=16384, headroom=4096 → need to shed 5705 backend
|
||||
tokens → shed 6846 tiktoken tokens from messages.
|
||||
"""
|
||||
cur_tiktoken = _count_message_tokens(msgs)
|
||||
headroom = n_ctx // 4 # reserve for generated output
|
||||
max_prompt = n_ctx - headroom # desired max backend tokens in prompt
|
||||
to_shed = max(0, actual_tokens - max_prompt) # backend tokens we must drop
|
||||
# Convert to tiktoken scale with 20% buffer (tiktoken underestimates llama by ~15-20%)
|
||||
tiktoken_to_shed = int(to_shed * 1.2)
|
||||
return max(1, cur_tiktoken - tiktoken_to_shed)
|
||||
|
||||
# Per-(endpoint, model) n_ctx cache.
|
||||
# Populated from two sources:
|
||||
# 1. 400 exceed_context_size_error body → n_ctx field
|
||||
# 2. finish_reason/done_reason == "length" in streaming → prompt_tokens + completion_tokens
|
||||
# Only used for proactive pre-trimming when n_ctx <= _CTX_TRIM_SMALL_LIMIT,
|
||||
# so large-context models (200k+ for coding) are never touched.
|
||||
_endpoint_nctx: dict[tuple[str, str], int] = {}
|
||||
_CTX_TRIM_SMALL_LIMIT = 32768 # only proactively trim models with n_ctx at or below this
|
||||
from security import _mask_secrets
|
||||
from context_window import (
|
||||
_count_message_tokens,
|
||||
_trim_messages_for_context,
|
||||
_calibrated_trim_target,
|
||||
_endpoint_nctx,
|
||||
_CTX_TRIM_SMALL_LIMIT,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Globals
|
||||
|
|
@ -463,36 +357,7 @@ _affinity_lock = asyncio.Lock()
|
|||
_AFFINITY_MAX_ENTRIES = 10000
|
||||
|
||||
|
||||
def _conversation_fingerprint(model: str, messages: Optional[list],
|
||||
prompt: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
Stable hash over (model, first system + first user turn). That prefix
|
||||
determines whether the backend's prompt cache is reusable; later turns
|
||||
don't influence the routing decision because they extend the same prefix.
|
||||
Returns None when there is no usable prefix.
|
||||
"""
|
||||
parts: list[str] = [model or "_"]
|
||||
if messages:
|
||||
for m in messages:
|
||||
role = m.get("role") if isinstance(m, dict) else None
|
||||
if role not in ("system", "user"):
|
||||
continue
|
||||
content = m.get("content")
|
||||
if isinstance(content, list): # OpenAI multimodal parts
|
||||
content = "".join(
|
||||
p.get("text", "") for p in content
|
||||
if isinstance(p, dict) and p.get("type") == "text"
|
||||
)
|
||||
if not isinstance(content, str):
|
||||
continue
|
||||
parts.append(f"{role}:{content}")
|
||||
if role == "user":
|
||||
break
|
||||
elif prompt:
|
||||
parts.append(f"user:{prompt}")
|
||||
else:
|
||||
return None
|
||||
return hashlib.sha1("\x1f".join(parts).encode("utf-8", "replace")).hexdigest()
|
||||
from fingerprint import _conversation_fingerprint
|
||||
|
||||
# Database instance
|
||||
db: "TokenDatabase" = None
|
||||
|
|
@ -1447,61 +1312,7 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
|
|||
# Generate final response
|
||||
return await _make_chat_request(model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
|
||||
|
||||
def iso8601_ns():
|
||||
ns = time.time_ns()
|
||||
sec, ns_rem = divmod(ns, 1_000_000_000)
|
||||
dt = datetime.fromtimestamp(sec, tz=timezone.utc)
|
||||
return (
|
||||
f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}T"
|
||||
f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}."
|
||||
f"{ns_rem:09d}Z"
|
||||
)
|
||||
|
||||
def is_base64(image_string):
|
||||
try:
|
||||
if isinstance(image_string, str) and base64.b64encode(base64.b64decode(image_string)) == image_string.encode():
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
def resize_image_if_needed(image_data):
|
||||
try:
|
||||
# Check if already data-url
|
||||
if image_data.startswith("data:"):
|
||||
try:
|
||||
header, image_data = image_data.split(",", 1)
|
||||
except ValueError:
|
||||
pass
|
||||
# Decode the base64 image data
|
||||
image_bytes = base64.b64decode(image_data)
|
||||
with Image.open(io.BytesIO(image_bytes)) as image:
|
||||
if image.mode not in ("RGB", "L"):
|
||||
image = image.convert("RGB")
|
||||
|
||||
# Get current size
|
||||
width, height = image.size
|
||||
|
||||
# Calculate the new dimensions while maintaining aspect ratio
|
||||
if width > 512 or height > 512:
|
||||
aspect_ratio = width / height
|
||||
if aspect_ratio > 1: # Width is larger
|
||||
new_width = 512
|
||||
new_height = int(512 / aspect_ratio)
|
||||
else: # Height is larger
|
||||
new_height = 512
|
||||
new_width = int(512 * aspect_ratio)
|
||||
|
||||
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# Encode the resized image back to base64
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
resized_image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
return resized_image_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {e}")
|
||||
return None
|
||||
from images import iso8601_ns, is_base64, resize_image_if_needed
|
||||
|
||||
def _strip_assistant_prefill(messages: list) -> list:
|
||||
"""Remove a trailing assistant message used as prefill.
|
||||
|
|
|
|||
14
security.py
Normal file
14
security.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
"""Secret-masking helpers used when logging or surfacing backend errors."""
|
||||
import re
|
||||
|
||||
|
||||
def _mask_secrets(text: str) -> str:
|
||||
"""
|
||||
Mask common API key patterns to avoid leaking secrets in logs or error payloads.
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
# OpenAI-style keys (sk-...) and generic "api key" mentions
|
||||
text = re.sub(r"sk-[A-Za-z0-9]{4}[A-Za-z0-9_-]*", "sk-***redacted***", text)
|
||||
text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text)
|
||||
return text
|
||||
Loading…
Add table
Add a link
Reference in a new issue