# nomyo-router/cache.py
"""
LLM Semantic Cache for NOMYO Router.
Strategy:
- Namespace: sha256(route :: model :: system_prompt)[:16] exact context isolation
- Cache key: hash(normalize(last_user_message), namespace) exact lookup
- Embedding: weighted mean of
α * embed(bm25_weighted(chat_history)) conversation context
1-α * embed(last_user_message) the actual question
with α = cache_history_weight (default 0.3).
- Exact-match caching (similarity=1.0) uses DummyEmbeddingProvider zero extra deps.
- Semantic caching (similarity<1.0) requires sentence-transformers. If missing the
library falls back to exact-match with a warning (lean Docker image behaviour).
- MOE models (moe-*) always bypass the cache.
- Token counts are never recorded for cache hits.
- Streaming cache hits are served as a single-chunk response.
- Privacy protection: responses that echo back user-identifying tokens from the system
prompt (names, emails, IDs) are stored WITHOUT an embedding. They remain findable
by exact-match for the same user but are invisible to cross-user semantic search.
Generic responses (capital of France "Paris") keep their embeddings and can still
produce cross-user semantic hits as intended.
2026-03-08 09:12:09 +01:00
"""
import hashlib
import math
import re
import time
import warnings
from collections import Counter
from typing import Any, Optional
# Lazily resolved once at first embed() call
_semantic_available: Optional[bool] = None
def _check_sentence_transformers() -> bool:
global _semantic_available
if _semantic_available is None:
try:
import sentence_transformers # noqa: F401
_semantic_available = True
except ImportError:
_semantic_available = False
return _semantic_available # type: ignore[return-value]
# ---------------------------------------------------------------------------
# BM25-weighted text representation of chat history
# ---------------------------------------------------------------------------
def _bm25_weighted_text(history: list[dict]) -> str:
"""
Produce a BM25-importance-weighted text string from chat history turns.
High-IDF (rare, domain-specific) terms are repeated proportionally to
their BM25 score so the downstream sentence-transformer embedding
naturally upweights topical signal and downweights stop words.
"""
docs = [m.get("content", "") for m in history if m.get("content")]
if not docs:
return ""
def _tok(text: str) -> list[str]:
return [w.lower() for w in text.split() if len(w) > 2]
tokenized = [_tok(d) for d in docs]
N = len(tokenized)
df: Counter = Counter()
for tokens in tokenized:
for term in set(tokens):
df[term] += 1
k1, b = 1.5, 0.75
avg_dl = sum(len(t) for t in tokenized) / max(N, 1)
term_scores: Counter = Counter()
for tokens in tokenized:
tf_c = Counter(tokens)
dl = len(tokens)
for term, tf in tf_c.items():
idf = math.log((N + 1) / (df[term] + 1)) + 1.0
score = idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * dl / max(avg_dl, 1)))
term_scores[term] += score
top = term_scores.most_common(50)
if not top:
return " ".join(docs)
max_s = top[0][1]
out: list[str] = []
for term, score in top:
out.extend([term] * max(1, round(3 * score / max_s)))
return " ".join(out)
# ---------------------------------------------------------------------------
# LLMCache
# ---------------------------------------------------------------------------
class LLMCache:
"""
Thin async wrapper around async-semantic-llm-cache that adds:
- Route-aware namespace isolation
- Two-vector weighted-mean embedding (history context + question)
- Per-instance hit/miss counters
- Graceful fallback when sentence-transformers is absent
"""
def __init__(self, cfg: Any) -> None:
self._cfg = cfg
self._backend: Any = None
self._emb_cache: Any = None
self._semantic: bool = False
self._hits: int = 0
self._misses: int = 0
async def init(self) -> None:
from semantic_llm_cache.similarity import EmbeddingCache
# --- Backend ---
backend_type: str = self._cfg.cache_backend
if backend_type == "sqlite":
from semantic_llm_cache.backends.sqlite import SQLiteBackend
self._backend = SQLiteBackend(db_path=self._cfg.cache_db_path)
elif backend_type == "redis":
from semantic_llm_cache.backends.redis import RedisBackend
self._backend = RedisBackend(url=self._cfg.cache_redis_url)
await self._backend.ping()
else:
from semantic_llm_cache.backends.memory import MemoryBackend
self._backend = MemoryBackend()
# --- Embedding provider ---
if self._cfg.cache_similarity < 1.0:
if _check_sentence_transformers():
from semantic_llm_cache.similarity import create_embedding_provider
provider = create_embedding_provider("sentence-transformer")
self._emb_cache = EmbeddingCache(provider=provider)
self._semantic = True
print(
f"[cache] Semantic cache ready "
f"(similarity≥{self._cfg.cache_similarity}, backend={backend_type})"
)
else:
warnings.warn(
"[cache] sentence-transformers is not installed. "
"Falling back to exact-match caching (similarity=1.0). "
"Use the :semantic Docker image tag to enable semantic caching.",
RuntimeWarning,
stacklevel=2,
)
self._emb_cache = EmbeddingCache() # DummyEmbeddingProvider
print(f"[cache] Exact-match cache ready (backend={backend_type}) [semantic unavailable]")
else:
self._emb_cache = EmbeddingCache() # DummyEmbeddingProvider
print(f"[cache] Exact-match cache ready (backend={backend_type})")
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _namespace(self, route: str, model: str, system: str) -> str:
raw = f"{route}::{model}::{system}"
return hashlib.sha256(raw.encode()).hexdigest()[:16]
def _cache_key(self, namespace: str, last_user: str) -> str:
from semantic_llm_cache.utils import hash_prompt, normalize_prompt
return hash_prompt(normalize_prompt(last_user), namespace)
# ------------------------------------------------------------------
# Privacy helpers — prevent cross-user leakage of personal data
# ------------------------------------------------------------------
_IDENTITY_STOPWORDS = frozenset({
"user", "users", "name", "names", "email", "phone", "their", "they",
"this", "that", "with", "from", "have", "been", "also", "more",
"tags", "identity", "preference", "context",
})
# Patterns that unambiguously signal personal data in a response
_EMAIL_RE = re.compile(r'\b[\w.%+-]+@[\w.-]+\.[a-zA-Z]{2,}\b')
_UUID_RE = re.compile(
r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b',
re.IGNORECASE,
)
# Standalone numeric run ≥ 8 digits (common user/account IDs)
_NUMERIC_ID_RE = re.compile(r'\b\d{8,}\b')
def _extract_response_content(self, response_bytes: bytes) -> str:
"""Parse response bytes (Ollama or OpenAI format) and return the text content."""
try:
import orjson
data = orjson.loads(response_bytes)
if "choices" in data: # OpenAI ChatCompletion
return (data["choices"][0].get("message") or {}).get("content", "")
if "message" in data: # Ollama chat
return (data.get("message") or {}).get("content", "")
if "response" in data: # Ollama generate
return data.get("response", "")
except Exception:
pass
return ""
def _extract_personal_tokens(self, system: str) -> frozenset[str]:
"""
Extract user-identifying tokens from the system prompt.
Looks for:
- Email addresses anywhere in the system prompt
- Numeric IDs ( 4 digits) appearing after "id" keyword
- Proper-noun-like words from [Tags: identity] lines
"""
if not system:
return frozenset()
tokens: set[str] = set()
# Email addresses
for email in self._EMAIL_RE.findall(system):
tokens.add(email.lower())
# Numeric IDs: "id: 1234", "id=5678"
for uid in re.findall(r'\bid\s*[=:]?\s*(\d{4,})\b', system, re.IGNORECASE):
tokens.add(uid)
# Values from [Tags: identity] lines (e.g. "User's name is Andreas")
for line in re.findall(
r'\[Tags:.*?identity.*?\]\s*(.+?)(?:\n|$)', system, re.IGNORECASE
):
for word in re.findall(r'\b\w{4,}\b', line):
w = word.lower()
if w not in self._IDENTITY_STOPWORDS:
tokens.add(w)
return frozenset(tokens)
def _response_is_personalized(self, response_bytes: bytes, system: str) -> bool:
"""
Return True if the response contains user-personal information.
Two complementary checks:
1. Direct PII detection in the response content emails, UUIDs, long
numeric IDs. Catches data retrieved at runtime via tool calls that
never appears in the system prompt.
2. System-prompt token overlap words extracted from [Tags: identity]
lines that reappear verbatim in the response (catches names, etc.).
Such responses are stored WITHOUT a semantic embedding so they are
invisible to cross-user semantic search while still being cacheable for
the same user via exact-match.
"""
content = self._extract_response_content(response_bytes)
if not content:
# Can't parse → err on the side of caution
return bool(response_bytes)
# 1. Direct PII patterns — independent of what's in the system prompt
if (
self._EMAIL_RE.search(content)
or self._UUID_RE.search(content)
or self._NUMERIC_ID_RE.search(content)
):
return True
# 2. System-prompt identity tokens echoed back in the response
personal = self._extract_personal_tokens(system)
if personal:
content_lower = content.lower()
if any(token in content_lower for token in personal):
return True
return False
2026-03-08 09:12:09 +01:00
def _parse_messages(
self, messages: list[dict]
) -> tuple[str, list[dict], str]:
"""
Returns (system_prompt, prior_history_turns, last_user_message).
Multimodal content lists are reduced to their text parts.
"""
system = ""
turns: list[dict] = []
for m in messages:
role = m.get("role", "")
content = m.get("content", "")
if isinstance(content, list):
content = " ".join(
p.get("text", "")
for p in content
if isinstance(p, dict) and p.get("type") == "text"
)
if role == "system":
system = content
else:
turns.append({"role": role, "content": content})
last_user = ""
for m in reversed(turns):
if m["role"] == "user":
last_user = m["content"]
break
# History = all turns before the final user message
history = turns[:-1] if turns and turns[-1]["role"] == "user" else turns
return system, history, last_user
async def _build_embedding(
self, history: list[dict], last_user: str
) -> list[float] | None:
"""
Weighted mean of BM25-weighted history embedding and last-user embedding.
Returns None when not in semantic mode.
"""
if not self._semantic:
return None
import numpy as np
alpha: float = self._cfg.cache_history_weight # weight for history signal
q_vec = np.array(await self._emb_cache.aencode(last_user), dtype=float)
if not history:
# No history → use question embedding alone (alpha has no effect)
return q_vec.tolist()
h_text = _bm25_weighted_text(history)
h_vec = np.array(await self._emb_cache.aencode(h_text), dtype=float)
combined = alpha * h_vec + (1.0 - alpha) * q_vec
norm = float(np.linalg.norm(combined))
if norm > 0.0:
combined /= norm
return combined.tolist()
# ------------------------------------------------------------------
# Public interface: chat (handles both Ollama and OpenAI message lists)
# ------------------------------------------------------------------
async def get_chat(
self, route: str, model: str, messages: list[dict]
) -> bytes | None:
"""Return cached response bytes, or None on miss."""
if not self._backend:
return None
system, history, last_user = self._parse_messages(messages)
if not last_user:
return None
ns = self._namespace(route, model, system)
key = self._cache_key(ns, last_user)
print(
f"[cache] get_chat route={route} model={model} ns={ns} "
f"prompt={last_user[:80]!r} "
f"system_snippet={system[:120]!r}"
)
2026-03-08 09:12:09 +01:00
# 1. Exact key match
entry = await self._backend.get(key)
if entry is not None:
self._hits += 1
print(f"[cache] HIT (exact) ns={ns} prompt={last_user[:80]!r}")
2026-03-08 09:12:09 +01:00
return entry.response # type: ignore[return-value]
# 2. Semantic similarity match
if self._semantic and self._cfg.cache_similarity < 1.0:
emb = await self._build_embedding(history, last_user)
result = await self._backend.find_similar(
emb, threshold=self._cfg.cache_similarity, namespace=ns
)
if result is not None:
_, matched, sim = result
2026-03-08 09:12:09 +01:00
self._hits += 1
print(
f"[cache] HIT (semantic sim={sim:.3f}) ns={ns} "
f"prompt={last_user[:80]!r} matched={matched.prompt[:80]!r}"
)
2026-03-08 09:12:09 +01:00
return matched.response # type: ignore[return-value]
self._misses += 1
print(f"[cache] MISS ns={ns} prompt={last_user[:80]!r}")
2026-03-08 09:12:09 +01:00
return None
async def set_chat(
self, route: str, model: str, messages: list[dict], response_bytes: bytes
) -> None:
"""Store a response in the cache (fire-and-forget friendly)."""
if not self._backend:
return
system, history, last_user = self._parse_messages(messages)
if not last_user:
return
ns = self._namespace(route, model, system)
key = self._cache_key(ns, last_user)
# Privacy guard: check whether the response contains personal data.
personalized = self._response_is_personalized(response_bytes, system)
if personalized:
# Exact-match is only safe when the system prompt is user-specific
# (i.e. different per user → different namespace). When the system
# prompt is generic and shared across all users, the namespace is the
# same for everyone: storing even without an embedding would let any
# user who asks the identical question get another user's personal data
# via exact-match. In that case skip storage entirely.
system_is_user_specific = bool(self._extract_personal_tokens(system))
if not system_is_user_specific:
print(
f"[cache] SKIP personalized response with generic system prompt "
f"route={route} model={model} ns={ns} prompt={last_user[:80]!r} "
f"system_snippet={system[:120]!r}"
)
return
print(
f"[cache] set_chat route={route} model={model} ns={ns} "
f"personalized={personalized} "
f"prompt={last_user[:80]!r} "
f"system_snippet={system[:120]!r}"
)
# Store without embedding when personalized (invisible to semantic search
# across users, but still reachable by exact-match within this namespace).
2026-03-08 09:12:09 +01:00
emb = (
await self._build_embedding(history, last_user)
if self._semantic and self._cfg.cache_similarity < 1.0 and not personalized
2026-03-08 09:12:09 +01:00
else None
)
from semantic_llm_cache.config import CacheEntry
await self._backend.set(
key,
CacheEntry(
prompt=last_user,
response=response_bytes,
embedding=emb,
created_at=time.time(),
ttl=self._cfg.cache_ttl,
namespace=ns,
hit_count=0,
),
)
# ------------------------------------------------------------------
# Convenience wrappers for the generate route (prompt string, not messages)
# ------------------------------------------------------------------
async def get_generate(
self, model: str, prompt: str, system: str = ""
) -> bytes | None:
messages: list[dict] = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
return await self.get_chat("generate", model, messages)
async def set_generate(
self, model: str, prompt: str, system: str, response_bytes: bytes
) -> None:
messages: list[dict] = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
await self.set_chat("generate", model, messages, response_bytes)
# ------------------------------------------------------------------
# Management
# ------------------------------------------------------------------
def stats(self) -> dict:
total = self._hits + self._misses
return {
"hits": self._hits,
"misses": self._misses,
"hit_rate": round(self._hits / total, 3) if total else 0.0,
"semantic": self._semantic,
"backend": self._cfg.cache_backend,
"similarity_threshold": self._cfg.cache_similarity,
"history_weight": self._cfg.cache_history_weight,
}
async def clear(self) -> None:
if self._backend:
await self._backend.clear()
self._hits = 0
self._misses = 0
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
_cache: LLMCache | None = None


async def init_llm_cache(cfg: Any) -> LLMCache | None:
    """Initialise the module-level cache singleton. Returns None if disabled."""
    global _cache
    if cfg.cache_enabled:
        _cache = LLMCache(cfg)
        await _cache.init()
        return _cache
    print("[cache] Cache disabled (cache_enabled=false).")
    return None


def get_llm_cache() -> LLMCache | None:
    """Return the singleton created by init_llm_cache(), or None."""
    return _cache
# ---------------------------------------------------------------------------
# Helper: convert a stored Ollama-format non-streaming response to an
# OpenAI SSE single-chunk stream (used when a streaming OpenAI request
# hits the cache whose entry was populated from a non-streaming response).
# ---------------------------------------------------------------------------
def openai_nonstream_to_sse(cached_bytes: bytes, model: str) -> bytes:
    """
    Wrap a stored OpenAI ChatCompletion JSON as a minimal single-chunk SSE stream.

    The stored entry always uses the non-streaming ChatCompletion format so that
    non-streaming cache hits can be served directly; this function adapts it for
    streaming clients.

    Args:
        cached_bytes: stored non-streaming ChatCompletion JSON payload.
        model: fallback model name if the payload carries none.

    Returns:
        SSE bytes: one data chunk followed by the ``[DONE]`` sentinel. On any
        failure (corrupt entry, orjson unavailable) a warning is emitted and
        just the ``[DONE]`` sentinel is returned.
    """
    import time as _time
    try:
        # Bug fix: import orjson INSIDE the try so a missing orjson degrades
        # to the empty stream instead of raising out of a cache-hit path.
        import orjson
        d = orjson.loads(cached_bytes)
        content = (d.get("choices") or [{}])[0].get("message", {}).get("content", "")
        chunk = {
            "id": d.get("id", "cache-hit"),
            "object": "chat.completion.chunk",
            "created": d.get("created", int(_time.time())),
            "model": d.get("model", model),
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": content},
                    "finish_reason": "stop",
                }
            ],
        }
        # Preserve token accounting when the stored entry carried usage data.
        if d.get("usage"):
            chunk["usage"] = d["usage"]
        return f"data: {orjson.dumps(chunk).decode()}\n\ndata: [DONE]\n\n".encode()
    except Exception as exc:
        warnings.warn(
            f"[cache] openai_nonstream_to_sse: corrupt cache entry, returning empty stream: {exc}",
            RuntimeWarning,
            stacklevel=2,
        )
        return b"data: [DONE]\n\n"