Mirror of https://github.com/MODSetter/SurfSense.git, synced 2026-05-12 09:12:40 +02:00
feat: enhance caching mechanisms to prevent memory leaks
- Improved in-memory rate limiting by evicting timestamps outside the current window and cleaning up empty keys.
- Updated LLM router service to cache context profiles and avoid redundant computations.
- Introduced cache eviction logic for MCP tools and sandbox instances to manage memory usage effectively.
- Added garbage collection triggers in chat streaming functions to reclaim resources promptly.
parent 08829c110c
commit f4b2ab0899

7 changed files with 127 additions and 60 deletions

@@ -22,6 +22,7 @@ from app.services.llm_router_service import (
     AUTO_MODE_ID,
     ChatLiteLLMRouter,
     LLMRouterService,
+    get_auto_mode_llm,
     is_auto_mode,
 )
@@ -389,7 +390,7 @@ def create_chat_litellm_from_agent_config(
         print("Error: Auto mode requested but LLM Router not initialized")
         return None
     try:
-        return ChatLiteLLMRouter()
+        return get_auto_mode_llm()
     except Exception as e:
         print(f"Error creating ChatLiteLLMRouter: {e}")
         return None
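
The net effect of this change is that the call site shares one cached wrapper instead of constructing a fresh ChatLiteLLMRouter per request. A minimal sketch of the expected behavior (assuming the router was initialized at startup; initialize() arguments are omitted here and may differ):

    from app.services.llm_router_service import LLMRouterService, get_auto_mode_llm

    LLMRouterService.initialize()  # assumed startup step, per the diff's error message
    llm_a = get_auto_mode_llm()
    llm_b = get_auto_mode_llm()
    assert llm_a is llm_b  # base instances are cached per streaming flag
    assert get_auto_mode_llm(streaming=False) is not llm_a  # separate cache slot
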
@@ -58,6 +58,7 @@ class _TimeoutAwareSandbox(DaytonaSandbox):
 _daytona_client: Daytona | None = None
 _sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
+_SANDBOX_CACHE_MAX_SIZE = 20
 THREAD_LABEL_KEY = "surfsense_thread"


@@ -144,6 +145,12 @@ async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
         return cached
     sandbox = await asyncio.to_thread(_find_or_create, key)
     _sandbox_cache[key] = sandbox
+
+    if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
+        oldest_key = next(iter(_sandbox_cache))
+        _sandbox_cache.pop(oldest_key, None)
+        logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key)
+
     return sandbox
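
The eviction above relies on CPython dicts preserving insertion order, so next(iter(_sandbox_cache)) is the oldest *inserted* key: FIFO, not LRU. A frequently used but long-lived sandbox can therefore be evicted before an idle newer one. A standalone sketch of the same policy (illustrative names, not repo code):

    cache: dict[str, str] = {}
    MAX_SIZE = 3

    def put(key: str, value: str) -> None:
        cache[key] = value
        if len(cache) > MAX_SIZE:
            oldest = next(iter(cache))  # oldest-inserted key
            cache.pop(oldest, None)

    for k in ("a", "b", "c", "d"):
        put(k, k.upper())
    print(list(cache))  # ['b', 'c', 'd'] -- 'a' went first, in insertion order
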
@@ -27,9 +27,24 @@ from app.db import SearchSourceConnector, SearchSourceConnectorType
 logger = logging.getLogger(__name__)

 _MCP_CACHE_TTL_SECONDS = 300  # 5 minutes
+_MCP_CACHE_MAX_SIZE = 50
 _mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}


+def _evict_expired_mcp_cache() -> None:
+    """Remove expired entries from the MCP tools cache to prevent unbounded growth."""
+    now = time.monotonic()
+    expired = [
+        k
+        for k, (ts, _) in _mcp_tools_cache.items()
+        if now - ts >= _MCP_CACHE_TTL_SECONDS
+    ]
+    for k in expired:
+        del _mcp_tools_cache[k]
+    if expired:
+        logger.debug("Evicted %d expired MCP cache entries", len(expired))
+
+
 def _create_dynamic_input_model_from_schema(
     tool_name: str,
     input_schema: dict[str, Any],
@@ -392,6 +407,8 @@ async def load_mcp_tools(
         List of LangChain StructuredTool instances

     """
+    _evict_expired_mcp_cache()
+
     now = time.monotonic()
     cached = _mcp_tools_cache.get(search_space_id)
     if cached is not None:
@@ -445,6 +462,11 @@ async def load_mcp_tools(
     )

     _mcp_tools_cache[search_space_id] = (now, tools)
+
+    if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
+        oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
+        del _mcp_tools_cache[oldest_key]
+
     logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
     return tools
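
Taken together, the hunks above give the MCP tools cache two independent bounds: a TTL sweep on every load and a hard size cap on insert. A self-contained sketch of the combined pattern (hypothetical names; the real code caches StructuredTool lists keyed by search space id):

    import time

    TTL_SECONDS = 300
    MAX_SIZE = 50
    _cache: dict[int, tuple[float, list[str]]] = {}

    def get_or_load(key: int, loader) -> list[str]:
        now = time.monotonic()
        # TTL sweep: drop expired entries so the dict cannot grow unbounded.
        for stale in [k for k, (ts, _) in _cache.items() if now - ts >= TTL_SECONDS]:
            del _cache[stale]
        hit = _cache.get(key)
        if hit is not None:
            return hit[1]  # still fresh: reuse the cached value
        value = loader(key)
        _cache[key] = (now, value)
        if len(_cache) > MAX_SIZE:
            # Size cap: evict the entry with the oldest timestamp.
            del _cache[min(_cache, key=lambda k: _cache[k][0])]
        return value

    tools = get_or_load(1, lambda k: [f"tool-{k}"])
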
@@ -103,22 +103,24 @@ def _check_rate_limit_memory(
     now = time.monotonic()

     with _memory_lock:
-        # Evict timestamps outside the current window
-        _memory_rate_limits[key] = [
-            t for t in _memory_rate_limits[key] if now - t < window_seconds
-        ]
+        timestamps = [t for t in _memory_rate_limits[key] if now - t < window_seconds]

-        if len(_memory_rate_limits[key]) >= max_requests:
+        if not timestamps:
+            _memory_rate_limits.pop(key, None)
+        else:
+            _memory_rate_limits[key] = timestamps
+
+        if len(timestamps) >= max_requests:
             rate_limit_logger.warning(
                 f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} "
-                f"({len(_memory_rate_limits[key])}/{max_requests} in {window_seconds}s)"
+                f"({len(timestamps)}/{max_requests} in {window_seconds}s)"
             )
             raise HTTPException(
                 status_code=429,
                 detail="RATE_LIMIT_EXCEEDED",
             )

-        _memory_rate_limits[key].append(now)
+        _memory_rate_limits[key] = [*timestamps, now]


 def _check_rate_limit(
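
The rewrite computes the surviving timestamps once, removes keys whose windows have fully expired (the plain indexing suggests _memory_rate_limits is a defaultdict(list), which would otherwise leave an empty list behind for every client ever seen), and rebuilds the list instead of mutating it in place. A self-contained sketch of the same sliding-window check, with hypothetical names:

    import threading
    import time
    from collections import defaultdict

    _lock = threading.Lock()
    _hits: defaultdict[str, list[float]] = defaultdict(list)

    def allow(key: str, max_requests: int, window_seconds: float) -> bool:
        now = time.monotonic()
        with _lock:
            # Keep only timestamps still inside the sliding window.
            timestamps = [t for t in _hits[key] if now - t < window_seconds]
            if not timestamps:
                _hits.pop(key, None)  # drop idle keys so the map stays small
            if len(timestamps) >= max_requests:
                _hits[key] = timestamps
                return False  # over the limit; this hit is not recorded
            _hits[key] = [*timestamps, now]
            return True

    assert all(allow("ip", 3, 60.0) for _ in range(3))
    assert not allow("ip", 3, 60.0)  # fourth request in the window is rejected
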
@@ -250,6 +250,48 @@ class LLMRouterService:
         return len(instance._model_list)


+_cached_context_profile: dict | None = None
+_cached_context_profile_computed: bool = False
+
+# Cached singleton instances keyed by (streaming,) to avoid re-creating on every call
+_router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {}
+
+
+def _get_cached_context_profile(router: Router) -> dict | None:
+    """Compute and cache the min context profile across all router deployments.
+
+    Called once on first ChatLiteLLMRouter creation; subsequent calls return
+    the cached value. This avoids calling litellm.get_model_info() for every
+    deployment on every request.
+    """
+    global _cached_context_profile, _cached_context_profile_computed
+    if _cached_context_profile_computed:
+        return _cached_context_profile
+
+    from litellm import get_model_info
+
+    min_ctx: int | None = None
+    for deployment in router.model_list:
+        params = deployment.get("litellm_params", {})
+        base_model = params.get("base_model") or params.get("model", "")
+        try:
+            info = get_model_info(base_model)
+            ctx = info.get("max_input_tokens")
+            if isinstance(ctx, int) and ctx > 0 and (min_ctx is None or ctx < min_ctx):
+                min_ctx = ctx
+        except Exception:
+            continue
+
+    if min_ctx is not None:
+        logger.info("ChatLiteLLMRouter profile: max_input_tokens=%d", min_ctx)
+        _cached_context_profile = {"max_input_tokens": min_ctx}
+    else:
+        _cached_context_profile = None
+
+    _cached_context_profile_computed = True
+    return _cached_context_profile
+
+
 class ChatLiteLLMRouter(BaseChatModel):
     """
     A LangChain-compatible chat model that uses LiteLLM Router for load balancing.
@@ -260,6 +302,10 @@ class ChatLiteLLMRouter(BaseChatModel):
     Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
     window across all router deployments so that deepagents
     SummarizationMiddleware can use fraction-based triggers.
+
+    **Singleton-ish**: Use ``get_auto_mode_llm()`` or call ``ChatLiteLLMRouter()``
+    directly — instances without bound tools are cached per streaming flag to
+    avoid per-request re-initialization overhead and memory growth.
     """

     # Use model_config for Pydantic v2 compatibility
@@ -281,14 +327,6 @@
         tool_choice: str | dict | None = None,
         **kwargs,
     ):
-        """
-        Initialize the ChatLiteLLMRouter.
-
-        Args:
-            router: LiteLLM Router instance. If None, uses the global singleton.
-            bound_tools: Pre-bound tools for tool calling
-            tool_choice: Tool choice configuration
-        """
         try:
             super().__init__(**kwargs)
             resolved_router = router or LLMRouterService.get_router()
@@ -300,51 +338,20 @@
                     "LLM Router not initialized. Call LLMRouterService.initialize() first."
                 )

             # Set profile so deepagents SummarizationMiddleware gets fraction-based triggers
-            computed_profile = self._compute_min_context_profile()
+            computed_profile = _get_cached_context_profile(self._router)
             if computed_profile is not None:
                 object.__setattr__(self, "profile", computed_profile)

-            logger.info(
-                f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
-            )
+            logger.debug(
+                "ChatLiteLLMRouter ready (models=%d, streaming=%s, has_tools=%s)",
+                LLMRouterService.get_model_count(),
+                self.streaming,
+                bound_tools is not None,
+            )
         except Exception as e:
             logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
             raise

-    def _compute_min_context_profile(self) -> dict | None:
-        """Derive a profile dict with max_input_tokens from router deployments.
-
-        Uses litellm.get_model_info to look up each deployment's context window
-        and picks the *minimum* so that summarization triggers before ANY model
-        in the pool overflows.
-        """
-        from litellm import get_model_info
-
-        if not self._router:
-            return None
-
-        min_ctx: int | None = None
-        for deployment in self._router.model_list:
-            params = deployment.get("litellm_params", {})
-            base_model = params.get("base_model") or params.get("model", "")
-            try:
-                info = get_model_info(base_model)
-                ctx = info.get("max_input_tokens")
-                if (
-                    isinstance(ctx, int)
-                    and ctx > 0
-                    and (min_ctx is None or ctx < min_ctx)
-                ):
-                    min_ctx = ctx
-            except Exception:
-                continue
-
-        if min_ctx is not None:
-            logger.info(f"ChatLiteLLMRouter profile: max_input_tokens={min_ctx}")
-            return {"max_input_tokens": min_ctx}
-        return None
-
     @property
     def _llm_type(self) -> str:
         return "litellm-router"
@@ -770,19 +777,28 @@
         return None


-def get_auto_mode_llm() -> ChatLiteLLMRouter | None:
-    """
-    Get a ChatLiteLLMRouter instance for auto mode.
+def get_auto_mode_llm(
+    *,
+    streaming: bool = True,
+) -> ChatLiteLLMRouter | None:
+    """Return a cached ChatLiteLLMRouter for auto mode.

-    Returns:
-        ChatLiteLLMRouter instance or None if router not initialized
+    Base (no tools) instances are cached per ``streaming`` flag so we
+    avoid re-constructing them on every request. ``bind_tools()`` still
+    returns a fresh instance because bound tools differ per agent.
     """
     if not LLMRouterService.is_initialized():
         logger.warning("LLM Router not initialized for auto mode")
         return None

+    cached = _router_instance_cache.get(streaming)
+    if cached is not None:
+        return cached
+
     try:
-        return ChatLiteLLMRouter()
+        instance = ChatLiteLLMRouter(streaming=streaming)
+        _router_instance_cache[streaming] = instance
+        return instance
     except Exception as e:
         logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
         return None
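
Note the separate _cached_context_profile_computed flag: it lets the cache memoize a None result, which a bare "is the cached value None?" test cannot distinguish from "never computed". A stripped-down illustration of that idiom (illustrative names only):

    _cached: dict | None = None
    _computed = False

    def min_context_profile() -> dict | None:
        """Compute once; afterwards return the memoized value, even when it is None."""
        global _cached, _computed
        if _computed:
            return _cached
        # Stand-in for the real work (get_model_info() across deployments).
        _cached = None
        _computed = True
        return _cached
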
@@ -12,6 +12,7 @@ from app.services.llm_router_service import (
     AUTO_MODE_ID,
     ChatLiteLLMRouter,
     LLMRouterService,
+    get_auto_mode_llm,
     is_auto_mode,
 )
@@ -221,7 +222,7 @@ async def get_search_space_llm_instance(
         logger.debug(
             f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
         )
-        return ChatLiteLLMRouter(disable_streaming=disable_streaming)
+        return get_auto_mode_llm(streaming=not disable_streaming)
     except Exception as e:
         logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
         return None

@@ -10,6 +10,7 @@ Supports loading LLM configurations from:
 """

 import asyncio
+import gc
 import json
 import logging
 import re
@@ -1476,6 +1477,16 @@ async def stream_new_chat(

     _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)

+    # Trigger a GC pass so LangGraph agent graphs, tool closures, and
+    # LLM wrappers with potential circular refs are reclaimed promptly.
+    collected = gc.collect()
+    if collected:
+        _perf_log.info(
+            "[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)",
+            collected,
+            chat_id,
+        )
+

 async def stream_resume_chat(
     chat_id: int,
@@ -1662,3 +1673,10 @@
     )

     _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
+    collected = gc.collect()
+    if collected:
+        _perf_log.info(
+            "[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)",
+            collected,
+            chat_id,
+        )
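
The explicit gc.collect() is a deliberate trade: a synchronous collection pass at the end of a long streaming request costs some CPU, but it frees cyclic structures (agent graphs, tool closures, LLM wrappers) immediately instead of waiting for the generational collector's allocation thresholds. A minimal standalone version of the pattern (hypothetical names):

    import gc
    import logging

    logging.basicConfig(level=logging.INFO)
    perf_log = logging.getLogger("perf")

    def finish_request(request_id: int) -> None:
        # Full collection pass; gc.collect() returns the number of
        # unreachable objects it found (0 when nothing cyclic was pending).
        collected = gc.collect()
        if collected:
            perf_log.info(
                "gc.collect() reclaimed %d objects (request_id=%s)",
                collected,
                request_id,
            )

    finish_request(42)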