From 800222dc23d0ce5a15842ce875069366fcf0b256 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 23 Apr 2026 00:25:19 -0700 Subject: [PATCH] slim down to jemalloc + debug endpoint + stress tests only --- config/plano_config_schema.yaml | 8 - crates/brightstaff/src/handlers/debug.rs | 43 ---- crates/brightstaff/src/main.rs | 7 +- crates/brightstaff/src/state/memory.rs | 245 +--------------------- crates/brightstaff/src/state/mod.rs | 6 - crates/brightstaff/src/tracing/init.rs | 2 - crates/common/src/configuration.rs | 6 - tests/stress/routing_stress.py | 255 ----------------------- 8 files changed, 5 insertions(+), 567 deletions(-) delete mode 100644 tests/stress/routing_stress.py diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml index d17e5e12..3439ebee 100644 --- a/config/plano_config_schema.yaml +++ b/config/plano_config_schema.yaml @@ -477,14 +477,6 @@ properties: connection_string: type: string description: Required when type is postgres. Supports environment variable substitution using $VAR or ${VAR} syntax. - ttl_seconds: - type: integer - minimum: 60 - description: TTL in seconds for in-memory state entries. Only applies when type is memory. Default 1800 (30 min). - max_entries: - type: integer - minimum: 100 - description: Maximum number of in-memory state entries. Only applies when type is memory. Default 10000. additionalProperties: false required: - type diff --git a/crates/brightstaff/src/handlers/debug.rs b/crates/brightstaff/src/handlers/debug.rs index 84ef3b87..58fbecd2 100644 --- a/crates/brightstaff/src/handlers/debug.rs +++ b/crates/brightstaff/src/handlers/debug.rs @@ -1,10 +1,8 @@ use bytes::Bytes; use http_body_util::combinators::BoxBody; use hyper::{Response, StatusCode}; -use std::sync::Arc; use super::full; -use crate::app_state::AppState; #[derive(serde::Serialize)] struct MemStats { @@ -30,7 +28,6 @@ pub async fn memstats() -> Result>, hyper: fn get_jemalloc_stats() -> MemStats { use tikv_jemalloc_ctl::{epoch, stats}; - // Advance the jemalloc stats epoch so numbers are fresh. if let Err(e) = epoch::advance() { return MemStats { allocated_bytes: 0, @@ -54,43 +51,3 @@ fn get_jemalloc_stats() -> MemStats { error: Some("jemalloc feature not enabled".to_string()), } } - -#[derive(serde::Serialize)] -struct StateSize { - entry_count: usize, - estimated_bytes: usize, - #[serde(skip_serializing_if = "Option::is_none")] - error: Option, -} - -/// Returns the number of entries and estimated byte size in the conversation state store. -pub async fn state_size( - state: Arc, -) -> Result>, hyper::Error> { - let result = match &state.state_storage { - Some(storage) => match storage.entry_stats().await { - Ok((count, bytes)) => StateSize { - entry_count: count, - estimated_bytes: bytes, - error: None, - }, - Err(e) => StateSize { - entry_count: 0, - estimated_bytes: 0, - error: Some(format!("{e}")), - }, - }, - None => StateSize { - entry_count: 0, - estimated_bytes: 0, - error: Some("no state_storage configured".to_string()), - }, - }; - - let json = serde_json::to_string(&result).unwrap(); - Ok(Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(full(json)) - .unwrap()) -} diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 9dfc9977..e24fa650 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -358,15 +358,11 @@ async fn init_state_storage( let storage: Arc = match storage_config.storage_type { common::configuration::StateStorageType::Memory => { - let ttl = storage_config.ttl_seconds.unwrap_or(1800); - let max = storage_config.max_entries.unwrap_or(10_000); info!( storage_type = "memory", - ttl_seconds = ttl, - max_entries = max, "initialized conversation state storage" ); - Arc::new(MemoryConversationalStorage::with_limits(ttl, max)) + Arc::new(MemoryConversationalStorage::new()) } common::configuration::StateStorageType::Postgres => { let connection_string = storage_config @@ -520,7 +516,6 @@ async fn dispatch( } (&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => cors_preflight(), (&Method::GET, "/debug/memstats") => debug::memstats().await, - (&Method::GET, "/debug/state_size") => debug::state_size(Arc::clone(&state)).await, _ => { debug!(method = %req.method(), path = %path, "no route found"); let mut not_found = Response::new(empty()); diff --git a/crates/brightstaff/src/state/memory.rs b/crates/brightstaff/src/state/memory.rs index c0c851e8..be4d8232 100644 --- a/crates/brightstaff/src/state/memory.rs +++ b/crates/brightstaff/src/state/memory.rs @@ -3,98 +3,21 @@ use async_trait::async_trait; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; -use tracing::{debug, info, warn}; +use tracing::{debug, warn}; -const DEFAULT_TTL_SECS: u64 = 1800; // 30 minutes -const DEFAULT_MAX_ENTRIES: usize = 10_000; -const EVICTION_INTERVAL_SECS: u64 = 60; - -/// In-memory storage backend for conversation state. -/// -/// Entries are evicted when they exceed `ttl_secs` or when the store grows -/// beyond `max_entries` (oldest-first by `created_at`). A background task -/// runs every 60 s to sweep expired entries. +/// In-memory storage backend for conversation state +/// Uses a HashMap wrapped in Arc> for thread-safe access #[derive(Clone)] pub struct MemoryConversationalStorage { storage: Arc>>, - max_entries: usize, } impl MemoryConversationalStorage { pub fn new() -> Self { - Self::with_limits(DEFAULT_TTL_SECS, DEFAULT_MAX_ENTRIES) - } - - pub fn with_limits(ttl_secs: u64, max_entries: usize) -> Self { - let storage = Arc::new(RwLock::new(HashMap::new())); - - let bg_storage = Arc::clone(&storage); - tokio::spawn(async move { - Self::eviction_loop(bg_storage, ttl_secs, max_entries).await; - }); - Self { - storage, - max_entries, + storage: Arc::new(RwLock::new(HashMap::new())), } } - - async fn eviction_loop( - storage: Arc>>, - ttl_secs: u64, - max_entries: usize, - ) { - let interval = std::time::Duration::from_secs(EVICTION_INTERVAL_SECS); - loop { - tokio::time::sleep(interval).await; - - let now = chrono::Utc::now().timestamp(); - let cutoff = now - ttl_secs as i64; - - let mut map = storage.write().await; - let before = map.len(); - - // Phase 1: remove expired entries - map.retain(|_, state| state.created_at > cutoff); - - // Phase 2: if still over capacity, drop oldest entries - if map.len() > max_entries { - let mut entries: Vec<(String, i64)> = - map.iter().map(|(k, v)| (k.clone(), v.created_at)).collect(); - entries.sort_by_key(|(_, ts)| *ts); - - let to_remove = map.len() - max_entries; - for (key, _) in entries.into_iter().take(to_remove) { - map.remove(&key); - } - } - - let evicted = before.saturating_sub(map.len()); - if evicted > 0 { - info!( - evicted, - remaining = map.len(), - "memory state store eviction sweep" - ); - } - } - } - - fn estimate_entry_bytes(state: &OpenAIConversationState) -> usize { - let base = std::mem::size_of::() - + state.response_id.len() - + state.model.len() - + state.provider.len(); - let items: usize = state - .input_items - .iter() - .map(|item| { - // Rough estimate: serialize to JSON and use its length as a proxy - serde_json::to_string(item).map(|s| s.len()).unwrap_or(64) - }) - .sum(); - base + items - } } impl Default for MemoryConversationalStorage { @@ -115,26 +38,6 @@ impl StateStorage for MemoryConversationalStorage { ); storage.insert(response_id, state); - - // Inline cap check so we don't wait for the background sweep - if storage.len() > self.max_entries { - let mut entries: Vec<(String, i64)> = storage - .iter() - .map(|(k, v)| (k.clone(), v.created_at)) - .collect(); - entries.sort_by_key(|(_, ts)| *ts); - - let to_remove = storage.len() - self.max_entries; - for (key, _) in entries.into_iter().take(to_remove) { - storage.remove(&key); - } - info!( - evicted = to_remove, - remaining = storage.len(), - "memory state store cap eviction on put" - ); - } - Ok(()) } @@ -177,13 +80,6 @@ impl StateStorage for MemoryConversationalStorage { Err(StateStorageError::NotFound(response_id.to_string())) } } - - async fn entry_stats(&self) -> Result<(usize, usize), StateStorageError> { - let storage = self.storage.read().await; - let count = storage.len(); - let bytes: usize = storage.values().map(Self::estimate_entry_bytes).sum(); - Ok((count, bytes)) - } } #[cfg(test)] @@ -739,137 +635,4 @@ mod tests { _ => panic!("Expected MessageContent::Items"), } } - - // ----------------------------------------------------------------------- - // Stress / eviction tests - // ----------------------------------------------------------------------- - - fn create_test_state_with_ts( - response_id: &str, - num_messages: usize, - created_at: i64, - ) -> OpenAIConversationState { - let mut state = create_test_state(response_id, num_messages); - state.created_at = created_at; - state - } - - #[tokio::test] - async fn test_max_entries_cap_evicts_oldest() { - let max = 100; - let storage = MemoryConversationalStorage::with_limits(3600, max); - - let now = chrono::Utc::now().timestamp(); - for i in 0..(max + 50) { - let state = - create_test_state_with_ts(&format!("resp_{i}"), 2, now - (max as i64) + i as i64); - storage.put(state).await.unwrap(); - } - - let (count, _) = storage.entry_stats().await.unwrap(); - assert_eq!(count, max, "store should be capped at max_entries"); - - // The oldest entries should have been evicted - assert!( - storage.get("resp_0").await.is_err(), - "oldest entry should be evicted" - ); - // The newest entry should still be present - let newest = format!("resp_{}", max + 49); - assert!( - storage.get(&newest).await.is_ok(), - "newest entry should be present" - ); - } - - #[tokio::test] - async fn test_ttl_eviction_removes_expired() { - let ttl_secs = 2; - let storage = MemoryConversationalStorage::with_limits(ttl_secs, 100_000); - - let now = chrono::Utc::now().timestamp(); - - // Insert entries that are already "old" (created_at in the past beyond TTL) - for i in 0..20 { - let state = - create_test_state_with_ts(&format!("old_{i}"), 1, now - (ttl_secs as i64) - 10); - storage.put(state).await.unwrap(); - } - - // Insert fresh entries - for i in 0..10 { - let state = create_test_state_with_ts(&format!("new_{i}"), 1, now); - storage.put(state).await.unwrap(); - } - - let (count_before, _) = storage.entry_stats().await.unwrap(); - assert_eq!(count_before, 30); - - // Wait for the eviction sweep (interval is 60s in prod, but the old - // entries are already past TTL so a manual trigger would clean them). - // For unit testing we directly call the eviction logic instead of waiting. - { - let bg_storage = storage.storage.clone(); - let cutoff = now - ttl_secs as i64; - let mut map = bg_storage.write().await; - map.retain(|_, state| state.created_at > cutoff); - } - - let (count_after, _) = storage.entry_stats().await.unwrap(); - assert_eq!( - count_after, 10, - "expired entries should be evicted, only fresh ones remain" - ); - - // Verify fresh entries are still present - for i in 0..10 { - assert!(storage.get(&format!("new_{i}")).await.is_ok()); - } - } - - #[tokio::test] - async fn test_entry_stats_reports_reasonable_size() { - let storage = MemoryConversationalStorage::with_limits(3600, 10_000); - - let now = chrono::Utc::now().timestamp(); - for i in 0..50 { - let state = create_test_state_with_ts(&format!("resp_{i}"), 5, now); - storage.put(state).await.unwrap(); - } - - let (count, bytes) = storage.entry_stats().await.unwrap(); - assert_eq!(count, 50); - assert!(bytes > 0, "estimated bytes should be positive"); - assert!( - bytes > 50 * 100, - "50 entries with 5 messages each should be at least a few KB" - ); - } - - #[tokio::test] - async fn test_high_concurrency_with_cap() { - let max = 500; - let storage = MemoryConversationalStorage::with_limits(3600, max); - let now = chrono::Utc::now().timestamp(); - - let mut handles = vec![]; - for i in 0..1000 { - let s = storage.clone(); - let handle = tokio::spawn(async move { - let state = create_test_state_with_ts(&format!("conc_{i}"), 3, now + i as i64); - s.put(state).await.unwrap(); - }); - handles.push(handle); - } - - for handle in handles { - handle.await.unwrap(); - } - - let (count, _) = storage.entry_stats().await.unwrap(); - assert!( - count <= max, - "store should never exceed max_entries, got {count}" - ); - } } diff --git a/crates/brightstaff/src/state/mod.rs b/crates/brightstaff/src/state/mod.rs index 794568a6..43454ee2 100644 --- a/crates/brightstaff/src/state/mod.rs +++ b/crates/brightstaff/src/state/mod.rs @@ -75,12 +75,6 @@ pub trait StateStorage: Send + Sync { /// Delete state for a response_id (optional, for cleanup) async fn delete(&self, response_id: &str) -> Result<(), StateStorageError>; - /// Return (entry_count, estimated_bytes) for observability. - /// Backends that cannot cheaply compute this may return (0, 0). - async fn entry_stats(&self) -> Result<(usize, usize), StateStorageError> { - Ok((0, 0)) - } - fn merge( &self, prev_state: &OpenAIConversationState, diff --git a/crates/brightstaff/src/tracing/init.rs b/crates/brightstaff/src/tracing/init.rs index 1457d52a..ed351148 100644 --- a/crates/brightstaff/src/tracing/init.rs +++ b/crates/brightstaff/src/tracing/init.rs @@ -109,8 +109,6 @@ pub fn init_tracer(tracing_config: Option<&Tracing>) -> &'static SdkTracerProvid let provider = SdkTracerProvider::builder() .with_batch_exporter(exporter) - .with_max_attributes_per_span(64) - .with_max_events_per_span(16) .build(); global::set_tracer_provider(provider.clone()); diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 2fbdfdec..028c8046 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -123,12 +123,6 @@ pub struct StateStorageConfig { #[serde(rename = "type")] pub storage_type: StateStorageType, pub connection_string: Option, - /// TTL in seconds for in-memory state entries (default: 1800 = 30 min). - /// Only applies when type is `memory`. - pub ttl_seconds: Option, - /// Maximum number of in-memory state entries (default: 10000). - /// Only applies when type is `memory`. - pub max_entries: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] diff --git a/tests/stress/routing_stress.py b/tests/stress/routing_stress.py deleted file mode 100644 index 152630cb..00000000 --- a/tests/stress/routing_stress.py +++ /dev/null @@ -1,255 +0,0 @@ -#!/usr/bin/env python3 -""" -Stress test for Plano routing service to detect memory leaks. - -Sends sustained traffic to the routing endpoint and monitors memory -via the /debug/memstats and /debug/state_size endpoints. - -Usage: - # Against a local Plano instance (docker or native) - python routing_stress.py --base-url http://localhost:12000 - - # Custom parameters - python routing_stress.py \ - --base-url http://localhost:12000 \ - --num-requests 5000 \ - --concurrency 20 \ - --poll-interval 5 \ - --growth-threshold 3.0 - -Requirements: - pip install httpx -""" -from __future__ import annotations - -import argparse -import asyncio -import json -import sys -import time -import uuid -from dataclasses import dataclass, field - -import httpx - - -@dataclass -class MemSnapshot: - timestamp: float - allocated_bytes: int - resident_bytes: int - state_entries: int - state_bytes: int - requests_completed: int - - -@dataclass -class StressResult: - snapshots: list[MemSnapshot] = field(default_factory=list) - total_requests: int = 0 - total_errors: int = 0 - elapsed_secs: float = 0.0 - passed: bool = True - failure_reason: str = "" - - -def make_routing_body(unique: bool = True) -> dict: - """Build a minimal chat-completions body for the routing endpoint.""" - return { - "model": "gpt-4o", - "messages": [ - {"role": "user", "content": f"test message {uuid.uuid4() if unique else 'static'}"} - ], - } - - -async def poll_debug_endpoints( - client: httpx.AsyncClient, - base_url: str, - requests_completed: int, -) -> MemSnapshot | None: - try: - mem_resp = await client.get(f"{base_url}/debug/memstats", timeout=5) - mem_data = mem_resp.json() - - state_resp = await client.get(f"{base_url}/debug/state_size", timeout=5) - state_data = state_resp.json() - - return MemSnapshot( - timestamp=time.time(), - allocated_bytes=mem_data.get("allocated_bytes", 0), - resident_bytes=mem_data.get("resident_bytes", 0), - state_entries=state_data.get("entry_count", 0), - state_bytes=state_data.get("estimated_bytes", 0), - requests_completed=requests_completed, - ) - except Exception as e: - print(f" [warn] failed to poll debug endpoints: {e}", file=sys.stderr) - return None - - -async def send_requests( - client: httpx.AsyncClient, - url: str, - count: int, - semaphore: asyncio.Semaphore, - counter: dict, -): - """Send `count` routing requests, respecting the concurrency semaphore.""" - for _ in range(count): - async with semaphore: - try: - body = make_routing_body(unique=True) - resp = await client.post(url, json=body, timeout=30) - if resp.status_code >= 400: - counter["errors"] += 1 - except Exception: - counter["errors"] += 1 - finally: - counter["completed"] += 1 - - -async def run_stress_test( - base_url: str, - num_requests: int, - concurrency: int, - poll_interval: float, - growth_threshold: float, -) -> StressResult: - result = StressResult() - routing_url = f"{base_url}/routing/v1/chat/completions" - - print(f"Stress test config:") - print(f" base_url: {base_url}") - print(f" routing_url: {routing_url}") - print(f" num_requests: {num_requests}") - print(f" concurrency: {concurrency}") - print(f" poll_interval: {poll_interval}s") - print(f" growth_threshold: {growth_threshold}x") - print() - - async with httpx.AsyncClient() as client: - # Take baseline snapshot - baseline = await poll_debug_endpoints(client, base_url, 0) - if baseline: - result.snapshots.append(baseline) - print(f"[baseline] allocated={baseline.allocated_bytes:,}B " - f"resident={baseline.resident_bytes:,}B " - f"state_entries={baseline.state_entries}") - else: - print("[warn] could not get baseline snapshot, continuing anyway") - - counter = {"completed": 0, "errors": 0} - semaphore = asyncio.Semaphore(concurrency) - - start = time.time() - - # Launch request sender and poller concurrently - sender = asyncio.create_task( - send_requests(client, routing_url, num_requests, semaphore, counter) - ) - - # Poll memory while requests are in flight - while not sender.done(): - await asyncio.sleep(poll_interval) - snapshot = await poll_debug_endpoints(client, base_url, counter["completed"]) - if snapshot: - result.snapshots.append(snapshot) - print( - f" [{counter['completed']:>6}/{num_requests}] " - f"allocated={snapshot.allocated_bytes:,}B " - f"resident={snapshot.resident_bytes:,}B " - f"state_entries={snapshot.state_entries} " - f"state_bytes={snapshot.state_bytes:,}B" - ) - - await sender - result.elapsed_secs = time.time() - start - result.total_requests = counter["completed"] - result.total_errors = counter["errors"] - - # Final snapshot - final = await poll_debug_endpoints(client, base_url, counter["completed"]) - if final: - result.snapshots.append(final) - - # Analyze results - print() - print(f"Completed {result.total_requests} requests in {result.elapsed_secs:.1f}s " - f"({result.total_errors} errors)") - - if len(result.snapshots) >= 2: - first = result.snapshots[0] - last = result.snapshots[-1] - - if first.resident_bytes > 0: - growth_ratio = last.resident_bytes / first.resident_bytes - print(f"Memory growth: {first.resident_bytes:,}B -> {last.resident_bytes:,}B " - f"({growth_ratio:.2f}x)") - - if growth_ratio > growth_threshold: - result.passed = False - result.failure_reason = ( - f"Memory grew {growth_ratio:.2f}x (threshold: {growth_threshold}x). " - f"Likely memory leak detected." - ) - print(f"FAIL: {result.failure_reason}") - else: - print(f"PASS: Memory growth {growth_ratio:.2f}x is within {growth_threshold}x threshold") - - print(f"State store: {last.state_entries} entries, {last.state_bytes:,}B") - else: - print("[warn] not enough snapshots to analyze memory growth") - - return result - - -def main(): - parser = argparse.ArgumentParser(description="Plano routing service stress test") - parser.add_argument("--base-url", default="http://localhost:12000", - help="Base URL of the Plano instance") - parser.add_argument("--num-requests", type=int, default=2000, - help="Total number of requests to send") - parser.add_argument("--concurrency", type=int, default=10, - help="Max concurrent requests") - parser.add_argument("--poll-interval", type=float, default=5.0, - help="Seconds between memory polls") - parser.add_argument("--growth-threshold", type=float, default=3.0, - help="Max allowed memory growth ratio (fail if exceeded)") - args = parser.parse_args() - - result = asyncio.run(run_stress_test( - base_url=args.base_url, - num_requests=args.num_requests, - concurrency=args.concurrency, - poll_interval=args.poll_interval, - growth_threshold=args.growth_threshold, - )) - - # Write results JSON for CI consumption - report = { - "passed": result.passed, - "failure_reason": result.failure_reason, - "total_requests": result.total_requests, - "total_errors": result.total_errors, - "elapsed_secs": result.elapsed_secs, - "snapshots": [ - { - "timestamp": s.timestamp, - "allocated_bytes": s.allocated_bytes, - "resident_bytes": s.resident_bytes, - "state_entries": s.state_entries, - "state_bytes": s.state_bytes, - "requests_completed": s.requests_completed, - } - for s in result.snapshots - ], - } - print() - print(json.dumps(report, indent=2)) - - sys.exit(0 if result.passed else 1) - - -if __name__ == "__main__": - main()