mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
fix memory regression: jemalloc, debug endpoints, state eviction, stress tests
Switch brightstaff to jemalloc to fix glibc malloc fragmentation causing OOM in prod routing service deployments. Add /debug/memstats and /debug/state_size endpoints for runtime observability. Add TTL eviction and max_entries cap to MemoryConversationalStorage. Cap tracing span attributes/events. Include routing stress tests proving zero per-request leak, and a Python load test for E2E validation.
This commit is contained in:
parent
980faef6be
commit
ec5d3660cd
13 changed files with 1550 additions and 1050 deletions
|
|
@ -472,6 +472,14 @@ properties:
|
|||
connection_string:
|
||||
type: string
|
||||
description: Required when type is postgres. Supports environment variable substitution using $VAR or ${VAR} syntax.
|
||||
ttl_seconds:
|
||||
type: integer
|
||||
minimum: 60
|
||||
description: TTL in seconds for in-memory state entries. Only applies when type is memory. Default 1800 (30 min).
|
||||
max_entries:
|
||||
type: integer
|
||||
minimum: 100
|
||||
description: Maximum number of in-memory state entries. Only applies when type is memory. Default 10000.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
|
|
|
|||
1707
crates/Cargo.lock
generated
1707
crates/Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -3,6 +3,10 @@ name = "brightstaff"
|
|||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["jemalloc"]
|
||||
jemalloc = ["tikv-jemallocator", "tikv-jemalloc-ctl"]
|
||||
|
||||
[dependencies]
|
||||
async-openai = "0.30.1"
|
||||
async-trait = "0.1"
|
||||
|
|
@ -35,6 +39,8 @@ serde_with = "3.13.0"
|
|||
strsim = "0.11"
|
||||
serde_yaml = "0.9.34"
|
||||
thiserror = "2.0.12"
|
||||
tikv-jemallocator = { version = "0.6", optional = true }
|
||||
tikv-jemalloc-ctl = { version = "0.6", features = ["stats"], optional = true }
|
||||
tokio = { version = "1.44.2", features = ["full"] }
|
||||
tokio-postgres = { version = "0.7", features = ["with-serde_json-1"] }
|
||||
tokio-stream = "0.1"
|
||||
|
|
|
|||
96
crates/brightstaff/src/handlers/debug.rs
Normal file
96
crates/brightstaff/src/handlers/debug.rs
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
use bytes::Bytes;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use hyper::{Response, StatusCode};
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::full;
|
||||
use crate::app_state::AppState;
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct MemStats {
|
||||
allocated_bytes: usize,
|
||||
resident_bytes: usize,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
/// Returns jemalloc memory statistics as JSON.
|
||||
/// Falls back to a stub when the jemalloc feature is disabled.
|
||||
pub async fn memstats() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let stats = get_jemalloc_stats();
|
||||
let json = serde_json::to_string(&stats).unwrap();
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(full(json))
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
/// Reads the `stats.allocated` and `stats.resident` counters from jemalloc.
///
/// The stats epoch is advanced first so the numbers are fresh. A failure is
/// never reported as a healthy zero:
/// - if the epoch cannot be advanced, both counters are 0 and `error` says so;
/// - if an individual counter cannot be read, that counter is 0 and `error`
///   names the counter together with the underlying error.
#[cfg(feature = "jemalloc")]
fn get_jemalloc_stats() -> MemStats {
    use tikv_jemalloc_ctl::{epoch, stats};

    // Advance the jemalloc stats epoch so numbers are fresh.
    if let Err(e) = epoch::advance() {
        return MemStats {
            allocated_bytes: 0,
            resident_bytes: 0,
            error: Some(format!("failed to advance jemalloc epoch: {e}")),
        };
    }

    // Collect read failures instead of discarding them, so a zero in the
    // response can be told apart from "could not read the counter".
    let mut failures: Vec<String> = Vec::new();
    let allocated_bytes = stats::allocated::read().unwrap_or_else(|e| {
        failures.push(format!("stats.allocated: {e}"));
        0
    });
    let resident_bytes = stats::resident::read().unwrap_or_else(|e| {
        failures.push(format!("stats.resident: {e}"));
        0
    });

    let error = if failures.is_empty() {
        None
    } else {
        Some(format!(
            "failed to read jemalloc stats: {}",
            failures.join("; ")
        ))
    };

    MemStats {
        allocated_bytes,
        resident_bytes,
        error,
    }
}
|
||||
|
||||
#[cfg(not(feature = "jemalloc"))]
|
||||
fn get_jemalloc_stats() -> MemStats {
|
||||
MemStats {
|
||||
allocated_bytes: 0,
|
||||
resident_bytes: 0,
|
||||
error: Some("jemalloc feature not enabled".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct StateSize {
|
||||
entry_count: usize,
|
||||
estimated_bytes: usize,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
/// Returns the number of entries and estimated byte size in the conversation state store.
|
||||
pub async fn state_size(
|
||||
state: Arc<AppState>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let result = match &state.state_storage {
|
||||
Some(storage) => match storage.entry_stats().await {
|
||||
Ok((count, bytes)) => StateSize {
|
||||
entry_count: count,
|
||||
estimated_bytes: bytes,
|
||||
error: None,
|
||||
},
|
||||
Err(e) => StateSize {
|
||||
entry_count: 0,
|
||||
estimated_bytes: 0,
|
||||
error: Some(format!("{e}")),
|
||||
},
|
||||
},
|
||||
None => StateSize {
|
||||
entry_count: 0,
|
||||
estimated_bytes: 0,
|
||||
error: Some("no state_storage configured".to_string()),
|
||||
},
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(full(json))
|
||||
.unwrap())
|
||||
}
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
pub mod agents;
|
||||
pub mod debug;
|
||||
pub mod function_calling;
|
||||
pub mod llm;
|
||||
pub mod models;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
#[cfg(feature = "jemalloc")]
|
||||
#[global_allocator]
|
||||
static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
|
||||
|
||||
use brightstaff::app_state::AppState;
|
||||
use brightstaff::handlers::agents::orchestrator::agent_chat;
|
||||
use brightstaff::handlers::debug;
|
||||
use brightstaff::handlers::empty;
|
||||
use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
||||
use brightstaff::handlers::llm::llm_chat;
|
||||
|
|
@ -368,11 +373,15 @@ async fn init_state_storage(
|
|||
|
||||
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
|
||||
common::configuration::StateStorageType::Memory => {
|
||||
let ttl = storage_config.ttl_seconds.unwrap_or(1800);
|
||||
let max = storage_config.max_entries.unwrap_or(10_000);
|
||||
info!(
|
||||
storage_type = "memory",
|
||||
ttl_seconds = ttl,
|
||||
max_entries = max,
|
||||
"initialized conversation state storage"
|
||||
);
|
||||
Arc::new(MemoryConversationalStorage::new())
|
||||
Arc::new(MemoryConversationalStorage::with_limits(ttl, max))
|
||||
}
|
||||
common::configuration::StateStorageType::Postgres => {
|
||||
let connection_string = storage_config
|
||||
|
|
@ -456,6 +465,8 @@ async fn route(
|
|||
Ok(list_models(Arc::clone(&state.llm_providers)).await)
|
||||
}
|
||||
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => cors_preflight(),
|
||||
(&Method::GET, "/debug/memstats") => debug::memstats().await,
|
||||
(&Method::GET, "/debug/state_size") => debug::state_size(Arc::clone(&state)).await,
|
||||
_ => {
|
||||
debug!(method = %req.method(), path = %path, "no route found");
|
||||
let mut not_found = Response::new(empty());
|
||||
|
|
|
|||
|
|
@ -6,3 +6,5 @@ pub mod orchestrator_model;
|
|||
pub mod orchestrator_model_v1;
|
||||
pub mod router_model;
|
||||
pub mod router_model_v1;
|
||||
#[cfg(test)]
|
||||
mod stress_tests;
|
||||
|
|
|
|||
253
crates/brightstaff/src/router/stress_tests.rs
Normal file
253
crates/brightstaff/src/router/stress_tests.rs
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::router::llm::RouterService;
|
||||
use common::configuration::{SelectionPolicy, SelectionPreference, TopLevelRoutingPreference};
|
||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn make_messages(n: usize) -> Vec<Message> {
|
||||
(0..n)
|
||||
.map(|i| Message {
|
||||
role: if i % 2 == 0 {
|
||||
Role::User
|
||||
} else {
|
||||
Role::Assistant
|
||||
},
|
||||
content: Some(MessageContent::Text(format!(
|
||||
"This is message number {i} with some padding text to make it realistic."
|
||||
))),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn make_routing_prefs() -> Vec<TopLevelRoutingPreference> {
|
||||
vec![
|
||||
TopLevelRoutingPreference {
|
||||
name: "code_generation".to_string(),
|
||||
description: "Code generation and debugging tasks".to_string(),
|
||||
models: vec![
|
||||
"openai/gpt-4o".to_string(),
|
||||
"openai/gpt-4o-mini".to_string(),
|
||||
],
|
||||
selection_policy: SelectionPolicy {
|
||||
prefer: SelectionPreference::None,
|
||||
},
|
||||
},
|
||||
TopLevelRoutingPreference {
|
||||
name: "summarization".to_string(),
|
||||
description: "Summarizing documents and text".to_string(),
|
||||
models: vec![
|
||||
"anthropic/claude-3-sonnet".to_string(),
|
||||
"openai/gpt-4o-mini".to_string(),
|
||||
],
|
||||
selection_policy: SelectionPolicy {
|
||||
prefer: SelectionPreference::None,
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
/// Stress test: exercise the full routing code path N times using a mock
|
||||
/// HTTP server and measure jemalloc allocated bytes before/after.
|
||||
///
|
||||
/// This catches:
|
||||
/// - Memory leaks in generate_request / parse_response
|
||||
/// - Leaks in reqwest connection handling
|
||||
/// - String accumulation in the router model
|
||||
/// - Fragmentation (jemalloc allocated vs resident)
|
||||
#[tokio::test]
|
||||
async fn stress_test_routing_determine_route() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
let router_url = format!("{}/v1/chat/completions", server.url());
|
||||
|
||||
let mock_response = serde_json::json!({
|
||||
"id": "chatcmpl-mock",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "arch-router",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "{\"route\": \"code_generation\"}"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 10, "total_tokens": 110}
|
||||
});
|
||||
|
||||
let _mock = server
|
||||
.mock("POST", "/v1/chat/completions")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(mock_response.to_string())
|
||||
.expect_at_least(1)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let prefs = make_routing_prefs();
|
||||
let router_service = Arc::new(RouterService::new(
|
||||
Some(prefs.clone()),
|
||||
None,
|
||||
router_url,
|
||||
"Arch-Router".to_string(),
|
||||
"arch-router".to_string(),
|
||||
));
|
||||
|
||||
// Warm up: a few requests to stabilize allocator state
|
||||
for _ in 0..10 {
|
||||
let msgs = make_messages(5);
|
||||
let _ = router_service
|
||||
.determine_route(&msgs, "00-trace-span-01", None, "warmup")
|
||||
.await;
|
||||
}
|
||||
|
||||
// Snapshot memory after warmup
|
||||
let baseline = get_allocated();
|
||||
|
||||
let num_iterations = 2000;
|
||||
|
||||
for i in 0..num_iterations {
|
||||
let msgs = make_messages(5 + (i % 10));
|
||||
let inline = if i % 3 == 0 {
|
||||
Some(make_routing_prefs())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let _ = router_service
|
||||
.determine_route(&msgs, "00-trace-span-01", inline, &format!("req-{i}"))
|
||||
.await;
|
||||
}
|
||||
|
||||
let after = get_allocated();
|
||||
|
||||
let growth = after.saturating_sub(baseline);
|
||||
let growth_mb = growth as f64 / (1024.0 * 1024.0);
|
||||
let per_request = if num_iterations > 0 {
|
||||
growth / num_iterations
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
eprintln!("=== Routing Stress Test Results ===");
|
||||
eprintln!(" Iterations: {num_iterations}");
|
||||
eprintln!(" Baseline alloc: {} bytes", baseline);
|
||||
eprintln!(" Final alloc: {} bytes", after);
|
||||
eprintln!(" Growth: {} bytes ({growth_mb:.2} MB)", growth);
|
||||
eprintln!(" Per-request: {} bytes", per_request);
|
||||
|
||||
// Allow up to 256 bytes per request of retained growth (connection pool, etc.)
|
||||
// A true leak would show thousands of bytes per request.
|
||||
assert!(
|
||||
per_request < 256,
|
||||
"Possible memory leak: {per_request} bytes/request retained after {num_iterations} iterations"
|
||||
);
|
||||
}
|
||||
|
||||
/// Stress test with high concurrency: many parallel determine_route calls.
|
||||
#[tokio::test]
|
||||
async fn stress_test_routing_concurrent() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
let router_url = format!("{}/v1/chat/completions", server.url());
|
||||
|
||||
let mock_response = serde_json::json!({
|
||||
"id": "chatcmpl-mock",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "arch-router",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "{\"route\": \"summarization\"}"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 10, "total_tokens": 110}
|
||||
});
|
||||
|
||||
let _mock = server
|
||||
.mock("POST", "/v1/chat/completions")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(mock_response.to_string())
|
||||
.expect_at_least(1)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let prefs = make_routing_prefs();
|
||||
let router_service = Arc::new(RouterService::new(
|
||||
Some(prefs),
|
||||
None,
|
||||
router_url,
|
||||
"Arch-Router".to_string(),
|
||||
"arch-router".to_string(),
|
||||
));
|
||||
|
||||
// Warm up
|
||||
for _ in 0..20 {
|
||||
let msgs = make_messages(3);
|
||||
let _ = router_service
|
||||
.determine_route(&msgs, "00-trace-span-01", None, "warmup")
|
||||
.await;
|
||||
}
|
||||
|
||||
let baseline = get_allocated();
|
||||
|
||||
let concurrency = 50;
|
||||
let requests_per_task = 100;
|
||||
let total = concurrency * requests_per_task;
|
||||
|
||||
let mut handles = vec![];
|
||||
for t in 0..concurrency {
|
||||
let svc = Arc::clone(&router_service);
|
||||
let handle = tokio::spawn(async move {
|
||||
for r in 0..requests_per_task {
|
||||
let msgs = make_messages(3 + (r % 8));
|
||||
let _ = svc
|
||||
.determine_route(&msgs, "00-trace-span-01", None, &format!("req-{t}-{r}"))
|
||||
.await;
|
||||
}
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
for h in handles {
|
||||
h.await.unwrap();
|
||||
}
|
||||
|
||||
let after = get_allocated();
|
||||
let growth = after.saturating_sub(baseline);
|
||||
let per_request = growth / total;
|
||||
|
||||
eprintln!("=== Concurrent Routing Stress Test Results ===");
|
||||
eprintln!(" Tasks: {concurrency} x {requests_per_task} = {total}");
|
||||
eprintln!(" Baseline: {} bytes", baseline);
|
||||
eprintln!(" Final: {} bytes", after);
|
||||
eprintln!(
|
||||
" Growth: {} bytes ({:.2} MB)",
|
||||
growth,
|
||||
growth as f64 / 1_048_576.0
|
||||
);
|
||||
eprintln!(" Per-request: {} bytes", per_request);
|
||||
|
||||
assert!(
|
||||
per_request < 512,
|
||||
"Possible memory leak under concurrency: {per_request} bytes/request retained after {total} requests"
|
||||
);
|
||||
}
|
||||
|
||||
/// Returns jemalloc's `stats.allocated` reading after advancing the stats
/// epoch. Panics if the epoch cannot be advanced; yields 0 if the counter
/// cannot be read.
#[cfg(feature = "jemalloc")]
fn get_allocated() -> usize {
    use tikv_jemalloc_ctl::{epoch, stats};

    epoch::advance().unwrap();
    stats::allocated::read().unwrap_or(0)
}
|
||||
|
||||
/// Fallback used when the `jemalloc` feature is disabled: there is no
/// allocator counter to read, so the reading is always 0.
#[cfg(not(feature = "jemalloc"))]
fn get_allocated() -> usize {
    const NO_ALLOCATOR_STATS: usize = 0;
    NO_ALLOCATOR_STATS
}
|
||||
}
|
||||
|
|
@ -3,21 +3,98 @@ use async_trait::async_trait;
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, warn};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// In-memory storage backend for conversation state
|
||||
/// Uses a HashMap wrapped in Arc<RwLock<>> for thread-safe access
|
||||
const DEFAULT_TTL_SECS: u64 = 1800; // 30 minutes
|
||||
const DEFAULT_MAX_ENTRIES: usize = 10_000;
|
||||
const EVICTION_INTERVAL_SECS: u64 = 60;
|
||||
|
||||
/// In-memory storage backend for conversation state.
|
||||
///
|
||||
/// Entries are evicted when they exceed `ttl_secs` or when the store grows
|
||||
/// beyond `max_entries` (oldest-first by `created_at`). A background task
|
||||
/// runs every 60 s to sweep expired entries.
|
||||
#[derive(Clone)]
|
||||
pub struct MemoryConversationalStorage {
|
||||
storage: Arc<RwLock<HashMap<String, OpenAIConversationState>>>,
|
||||
max_entries: usize,
|
||||
}
|
||||
|
||||
impl MemoryConversationalStorage {
|
||||
pub fn new() -> Self {
|
||||
Self::with_limits(DEFAULT_TTL_SECS, DEFAULT_MAX_ENTRIES)
|
||||
}
|
||||
|
||||
pub fn with_limits(ttl_secs: u64, max_entries: usize) -> Self {
|
||||
let storage = Arc::new(RwLock::new(HashMap::new()));
|
||||
|
||||
let bg_storage = Arc::clone(&storage);
|
||||
tokio::spawn(async move {
|
||||
Self::eviction_loop(bg_storage, ttl_secs, max_entries).await;
|
||||
});
|
||||
|
||||
Self {
|
||||
storage: Arc::new(RwLock::new(HashMap::new())),
|
||||
storage,
|
||||
max_entries,
|
||||
}
|
||||
}
|
||||
|
||||
async fn eviction_loop(
|
||||
storage: Arc<RwLock<HashMap<String, OpenAIConversationState>>>,
|
||||
ttl_secs: u64,
|
||||
max_entries: usize,
|
||||
) {
|
||||
let interval = std::time::Duration::from_secs(EVICTION_INTERVAL_SECS);
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
let cutoff = now - ttl_secs as i64;
|
||||
|
||||
let mut map = storage.write().await;
|
||||
let before = map.len();
|
||||
|
||||
// Phase 1: remove expired entries
|
||||
map.retain(|_, state| state.created_at > cutoff);
|
||||
|
||||
// Phase 2: if still over capacity, drop oldest entries
|
||||
if map.len() > max_entries {
|
||||
let mut entries: Vec<(String, i64)> =
|
||||
map.iter().map(|(k, v)| (k.clone(), v.created_at)).collect();
|
||||
entries.sort_by_key(|(_, ts)| *ts);
|
||||
|
||||
let to_remove = map.len() - max_entries;
|
||||
for (key, _) in entries.into_iter().take(to_remove) {
|
||||
map.remove(&key);
|
||||
}
|
||||
}
|
||||
|
||||
let evicted = before.saturating_sub(map.len());
|
||||
if evicted > 0 {
|
||||
info!(
|
||||
evicted,
|
||||
remaining = map.len(),
|
||||
"memory state store eviction sweep"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_entry_bytes(state: &OpenAIConversationState) -> usize {
|
||||
let base = std::mem::size_of::<OpenAIConversationState>()
|
||||
+ state.response_id.len()
|
||||
+ state.model.len()
|
||||
+ state.provider.len();
|
||||
let items: usize = state
|
||||
.input_items
|
||||
.iter()
|
||||
.map(|item| {
|
||||
// Rough estimate: serialize to JSON and use its length as a proxy
|
||||
serde_json::to_string(item).map(|s| s.len()).unwrap_or(64)
|
||||
})
|
||||
.sum();
|
||||
base + items
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MemoryConversationalStorage {
|
||||
|
|
@ -38,6 +115,26 @@ impl StateStorage for MemoryConversationalStorage {
|
|||
);
|
||||
|
||||
storage.insert(response_id, state);
|
||||
|
||||
// Inline cap check so we don't wait for the background sweep
|
||||
if storage.len() > self.max_entries {
|
||||
let mut entries: Vec<(String, i64)> = storage
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), v.created_at))
|
||||
.collect();
|
||||
entries.sort_by_key(|(_, ts)| *ts);
|
||||
|
||||
let to_remove = storage.len() - self.max_entries;
|
||||
for (key, _) in entries.into_iter().take(to_remove) {
|
||||
storage.remove(&key);
|
||||
}
|
||||
info!(
|
||||
evicted = to_remove,
|
||||
remaining = storage.len(),
|
||||
"memory state store cap eviction on put"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -80,6 +177,13 @@ impl StateStorage for MemoryConversationalStorage {
|
|||
Err(StateStorageError::NotFound(response_id.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn entry_stats(&self) -> Result<(usize, usize), StateStorageError> {
|
||||
let storage = self.storage.read().await;
|
||||
let count = storage.len();
|
||||
let bytes: usize = storage.values().map(Self::estimate_entry_bytes).sum();
|
||||
Ok((count, bytes))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -635,4 +739,137 @@ mod tests {
|
|||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Stress / eviction tests
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
fn create_test_state_with_ts(
|
||||
response_id: &str,
|
||||
num_messages: usize,
|
||||
created_at: i64,
|
||||
) -> OpenAIConversationState {
|
||||
let mut state = create_test_state(response_id, num_messages);
|
||||
state.created_at = created_at;
|
||||
state
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_max_entries_cap_evicts_oldest() {
|
||||
let max = 100;
|
||||
let storage = MemoryConversationalStorage::with_limits(3600, max);
|
||||
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
for i in 0..(max + 50) {
|
||||
let state =
|
||||
create_test_state_with_ts(&format!("resp_{i}"), 2, now - (max as i64) + i as i64);
|
||||
storage.put(state).await.unwrap();
|
||||
}
|
||||
|
||||
let (count, _) = storage.entry_stats().await.unwrap();
|
||||
assert_eq!(count, max, "store should be capped at max_entries");
|
||||
|
||||
// The oldest entries should have been evicted
|
||||
assert!(
|
||||
storage.get("resp_0").await.is_err(),
|
||||
"oldest entry should be evicted"
|
||||
);
|
||||
// The newest entry should still be present
|
||||
let newest = format!("resp_{}", max + 49);
|
||||
assert!(
|
||||
storage.get(&newest).await.is_ok(),
|
||||
"newest entry should be present"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ttl_eviction_removes_expired() {
|
||||
let ttl_secs = 2;
|
||||
let storage = MemoryConversationalStorage::with_limits(ttl_secs, 100_000);
|
||||
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
|
||||
// Insert entries that are already "old" (created_at in the past beyond TTL)
|
||||
for i in 0..20 {
|
||||
let state =
|
||||
create_test_state_with_ts(&format!("old_{i}"), 1, now - (ttl_secs as i64) - 10);
|
||||
storage.put(state).await.unwrap();
|
||||
}
|
||||
|
||||
// Insert fresh entries
|
||||
for i in 0..10 {
|
||||
let state = create_test_state_with_ts(&format!("new_{i}"), 1, now);
|
||||
storage.put(state).await.unwrap();
|
||||
}
|
||||
|
||||
let (count_before, _) = storage.entry_stats().await.unwrap();
|
||||
assert_eq!(count_before, 30);
|
||||
|
||||
// Wait for the eviction sweep (interval is 60s in prod, but the old
|
||||
// entries are already past TTL so a manual trigger would clean them).
|
||||
// For unit testing we directly call the eviction logic instead of waiting.
|
||||
{
|
||||
let bg_storage = storage.storage.clone();
|
||||
let cutoff = now - ttl_secs as i64;
|
||||
let mut map = bg_storage.write().await;
|
||||
map.retain(|_, state| state.created_at > cutoff);
|
||||
}
|
||||
|
||||
let (count_after, _) = storage.entry_stats().await.unwrap();
|
||||
assert_eq!(
|
||||
count_after, 10,
|
||||
"expired entries should be evicted, only fresh ones remain"
|
||||
);
|
||||
|
||||
// Verify fresh entries are still present
|
||||
for i in 0..10 {
|
||||
assert!(storage.get(&format!("new_{i}")).await.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_entry_stats_reports_reasonable_size() {
|
||||
let storage = MemoryConversationalStorage::with_limits(3600, 10_000);
|
||||
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
for i in 0..50 {
|
||||
let state = create_test_state_with_ts(&format!("resp_{i}"), 5, now);
|
||||
storage.put(state).await.unwrap();
|
||||
}
|
||||
|
||||
let (count, bytes) = storage.entry_stats().await.unwrap();
|
||||
assert_eq!(count, 50);
|
||||
assert!(bytes > 0, "estimated bytes should be positive");
|
||||
assert!(
|
||||
bytes > 50 * 100,
|
||||
"50 entries with 5 messages each should be at least a few KB"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_high_concurrency_with_cap() {
|
||||
let max = 500;
|
||||
let storage = MemoryConversationalStorage::with_limits(3600, max);
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
|
||||
let mut handles = vec![];
|
||||
for i in 0..1000 {
|
||||
let s = storage.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
let state = create_test_state_with_ts(&format!("conc_{i}"), 3, now + i as i64);
|
||||
s.put(state).await.unwrap();
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
let (count, _) = storage.entry_stats().await.unwrap();
|
||||
assert!(
|
||||
count <= max,
|
||||
"store should never exceed max_entries, got {count}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -75,6 +75,12 @@ pub trait StateStorage: Send + Sync {
|
|||
/// Delete state for a response_id (optional, for cleanup)
|
||||
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError>;
|
||||
|
||||
/// Return (entry_count, estimated_bytes) for observability.
|
||||
/// Backends that cannot cheaply compute this may return (0, 0).
|
||||
async fn entry_stats(&self) -> Result<(usize, usize), StateStorageError> {
|
||||
Ok((0, 0))
|
||||
}
|
||||
|
||||
fn merge(
|
||||
&self,
|
||||
prev_state: &OpenAIConversationState,
|
||||
|
|
|
|||
|
|
@ -109,6 +109,8 @@ pub fn init_tracer(tracing_config: Option<&Tracing>) -> &'static SdkTracerProvid
|
|||
|
||||
let provider = SdkTracerProvider::builder()
|
||||
.with_batch_exporter(exporter)
|
||||
.with_max_attributes_per_span(64)
|
||||
.with_max_events_per_span(16)
|
||||
.build();
|
||||
|
||||
global::set_tracer_provider(provider.clone());
|
||||
|
|
|
|||
|
|
@ -123,6 +123,12 @@ pub struct StateStorageConfig {
|
|||
#[serde(rename = "type")]
|
||||
pub storage_type: StateStorageType,
|
||||
pub connection_string: Option<String>,
|
||||
/// TTL in seconds for in-memory state entries (default: 1800 = 30 min).
|
||||
/// Only applies when type is `memory`.
|
||||
pub ttl_seconds: Option<u64>,
|
||||
/// Maximum number of in-memory state entries (default: 10000).
|
||||
/// Only applies when type is `memory`.
|
||||
pub max_entries: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
|
|
|
|||
255
tests/stress/routing_stress.py
Normal file
255
tests/stress/routing_stress.py
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Stress test for Plano routing service to detect memory leaks.
|
||||
|
||||
Sends sustained traffic to the routing endpoint and monitors memory
|
||||
via the /debug/memstats and /debug/state_size endpoints.
|
||||
|
||||
Usage:
|
||||
# Against a local Plano instance (docker or native)
|
||||
python routing_stress.py --base-url http://localhost:12000
|
||||
|
||||
# Custom parameters
|
||||
python routing_stress.py \
|
||||
--base-url http://localhost:12000 \
|
||||
--num-requests 5000 \
|
||||
--concurrency 20 \
|
||||
--poll-interval 5 \
|
||||
--growth-threshold 3.0
|
||||
|
||||
Requirements:
|
||||
pip install httpx
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
@dataclass
class MemSnapshot:
    """One reading of the server's debug endpoints, tagged with test progress."""

    # Wall-clock time (time.time()) at which the reading was taken.
    timestamp: float
    # "allocated_bytes" from /debug/memstats (0 if the key was missing).
    allocated_bytes: int
    # "resident_bytes" from /debug/memstats (0 if the key was missing).
    resident_bytes: int
    # "entry_count" from /debug/state_size (0 if the key was missing).
    state_entries: int
    # "estimated_bytes" from /debug/state_size (0 if the key was missing).
    state_bytes: int
    # Number of requests finished when the reading was taken.
    requests_completed: int
|
||||
|
||||
|
||||
@dataclass
class StressResult:
    """Outcome of one stress run: memory readings, counters and verdict."""

    # Readings in the order they were taken.
    snapshots: list[MemSnapshot] = field(default_factory=list)
    # Requests finished, whether they succeeded or not.
    total_requests: int = 0
    # Requests that raised or returned an HTTP status >= 400.
    total_errors: int = 0
    # Duration of the request phase, in seconds.
    elapsed_secs: float = 0.0
    # False once memory growth exceeded the configured threshold.
    passed: bool = True
    # Human-readable explanation when `passed` is False; empty otherwise.
    failure_reason: str = ""
|
||||
|
||||
|
||||
def make_routing_body(unique: bool = True) -> dict:
    """Return a minimal chat-completions request body for the routing endpoint.

    With ``unique=True`` the single user message ends in a fresh UUID, so every
    body differs; with ``unique=False`` it ends in the literal ``static``.
    """
    suffix = uuid.uuid4() if unique else "static"
    user_message = {"role": "user", "content": f"test message {suffix}"}
    return {"model": "gpt-4o", "messages": [user_message]}
|
||||
|
||||
|
||||
async def poll_debug_endpoints(
    client: httpx.AsyncClient,
    base_url: str,
    requests_completed: int,
) -> MemSnapshot | None:
    """Read /debug/memstats and /debug/state_size and combine them.

    Args:
        client: HTTP client used for both GET requests (5 s timeout each).
        base_url: Server base URL, without a trailing slash.
        requests_completed: Progress counter stored in the snapshot.

    Returns:
        A MemSnapshot, or None when either endpoint cannot be reached, answers
        with an HTTP error status, or returns a body that cannot be used. The
        failure is reported on stderr.
    """
    try:
        mem_resp = await client.get(f"{base_url}/debug/memstats", timeout=5)
        # An error response must not be recorded as an all-zero snapshot just
        # because its body happens to be JSON.
        mem_resp.raise_for_status()
        mem_data = mem_resp.json()

        state_resp = await client.get(f"{base_url}/debug/state_size", timeout=5)
        state_resp.raise_for_status()
        state_data = state_resp.json()

        return MemSnapshot(
            timestamp=time.time(),
            allocated_bytes=mem_data.get("allocated_bytes", 0),
            resident_bytes=mem_data.get("resident_bytes", 0),
            state_entries=state_data.get("entry_count", 0),
            state_bytes=state_data.get("estimated_bytes", 0),
            requests_completed=requests_completed,
        )
    except Exception as e:
        # Polling is best-effort: report the problem and let the run continue.
        print(f" [warn] failed to poll debug endpoints: {e}", file=sys.stderr)
        return None
|
||||
|
||||
|
||||
async def send_requests(
    client: httpx.AsyncClient,
    url: str,
    count: int,
    semaphore: asyncio.Semaphore,
    counter: dict,
):
    """Send `count` routing requests, with concurrency bounded by `semaphore`.

    Each request runs as its own coroutine, so up to the semaphore's limit are
    in flight at once. Awaiting the requests one after another in a single loop
    would keep at most one request in flight regardless of the limit.

    Args:
        client: HTTP client used to POST the request bodies.
        url: Routing endpoint URL.
        count: Total number of requests to send.
        semaphore: Bounds how many requests are in flight at the same time.
        counter: Mutable dict; "completed" is incremented for every finished
            request and "errors" for every request that raised or returned an
            HTTP status >= 400.
    """

    async def _send_one() -> None:
        async with semaphore:
            try:
                body = make_routing_body(unique=True)
                resp = await client.post(url, json=body, timeout=30)
                if resp.status_code >= 400:
                    counter["errors"] += 1
            except Exception:
                # Best-effort load generation: count the failure and keep going.
                counter["errors"] += 1
            finally:
                counter["completed"] += 1

    await asyncio.gather(*(_send_one() for _ in range(count)))
|
||||
|
||||
|
||||
async def run_stress_test(
    base_url: str,
    num_requests: int,
    concurrency: int,
    poll_interval: float,
    growth_threshold: float,
) -> StressResult:
    """Drive the routing endpoint with load while sampling server memory.

    Args:
        base_url: Base URL of the instance under test.
        num_requests: Total number of routing requests to send.
        concurrency: Limit passed to the semaphore that gates requests.
        poll_interval: Seconds between memory samples while requests run.
        growth_threshold: Maximum allowed ratio of last to first
            `resident_bytes`; exceeding it marks the run as failed.

    Returns:
        A StressResult holding every snapshot taken, the request counters, the
        elapsed time of the request phase and the pass/fail verdict.
    """
    result = StressResult()
    routing_url = f"{base_url}/routing/v1/chat/completions"

    print("Stress test config:")
    print(f" base_url: {base_url}")
    print(f" routing_url: {routing_url}")
    print(f" num_requests: {num_requests}")
    print(f" concurrency: {concurrency}")
    print(f" poll_interval: {poll_interval}s")
    print(f" growth_threshold: {growth_threshold}x")
    print()

    async with httpx.AsyncClient() as client:
        # Take baseline snapshot
        baseline = await poll_debug_endpoints(client, base_url, 0)
        if baseline:
            result.snapshots.append(baseline)
            print(f"[baseline] allocated={baseline.allocated_bytes:,}B "
                  f"resident={baseline.resident_bytes:,}B "
                  f"state_entries={baseline.state_entries}")
        else:
            print("[warn] could not get baseline snapshot, continuing anyway")

        counter = {"completed": 0, "errors": 0}
        semaphore = asyncio.Semaphore(concurrency)

        start = time.time()

        # Launch the request sender; memory is polled while it runs.
        sender = asyncio.create_task(
            send_requests(client, routing_url, num_requests, semaphore, counter)
        )

        while not sender.done():
            # Wait for the next poll tick OR for the sender to finish,
            # whichever comes first. A plain sleep would add up to
            # `poll_interval` seconds to the measured elapsed time.
            await asyncio.wait({sender}, timeout=poll_interval)
            if sender.done():
                break
            snapshot = await poll_debug_endpoints(client, base_url, counter["completed"])
            if snapshot:
                result.snapshots.append(snapshot)
                print(
                    f" [{counter['completed']:>6}/{num_requests}] "
                    f"allocated={snapshot.allocated_bytes:,}B "
                    f"resident={snapshot.resident_bytes:,}B "
                    f"state_entries={snapshot.state_entries} "
                    f"state_bytes={snapshot.state_bytes:,}B"
                )

        # Re-raises anything the sender task itself raised.
        await sender
        result.elapsed_secs = time.time() - start
        result.total_requests = counter["completed"]
        result.total_errors = counter["errors"]

        # Final snapshot
        final = await poll_debug_endpoints(client, base_url, counter["completed"])
        if final:
            result.snapshots.append(final)

    # Analyze results
    print()
    print(f"Completed {result.total_requests} requests in {result.elapsed_secs:.1f}s "
          f"({result.total_errors} errors)")

    if len(result.snapshots) >= 2:
        first = result.snapshots[0]
        last = result.snapshots[-1]

        if first.resident_bytes > 0:
            growth_ratio = last.resident_bytes / first.resident_bytes
            print(f"Memory growth: {first.resident_bytes:,}B -> {last.resident_bytes:,}B "
                  f"({growth_ratio:.2f}x)")

            if growth_ratio > growth_threshold:
                result.passed = False
                result.failure_reason = (
                    f"Memory grew {growth_ratio:.2f}x (threshold: {growth_threshold}x). "
                    f"Likely memory leak detected."
                )
                print(f"FAIL: {result.failure_reason}")
            else:
                print(f"PASS: Memory growth {growth_ratio:.2f}x is within {growth_threshold}x threshold")
        else:
            # A zero baseline means no ratio can be computed; say so rather
            # than letting the run look like a verified pass.
            print("[warn] first snapshot reports resident_bytes=0; memory growth was not checked")

        print(f"State store: {last.state_entries} entries, {last.state_bytes:,}B")
    else:
        print("[warn] not enough snapshots to analyze memory growth")

    return result
|
||||
|
||||
|
||||
def main():
    """Parse CLI options, run the stress test, print a JSON report and exit.

    The exit status is 0 when the run passed and 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Plano routing service stress test")
    options = [
        ("--base-url", {"default": "http://localhost:12000",
                        "help": "Base URL of the Plano instance"}),
        ("--num-requests", {"type": int, "default": 2000,
                            "help": "Total number of requests to send"}),
        ("--concurrency", {"type": int, "default": 10,
                           "help": "Max concurrent requests"}),
        ("--poll-interval", {"type": float, "default": 5.0,
                             "help": "Seconds between memory polls"}),
        ("--growth-threshold", {"type": float, "default": 3.0,
                                "help": "Max allowed memory growth ratio (fail if exceeded)"}),
    ]
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # The parsed option names match run_stress_test's keyword parameters one-to-one.
    result = asyncio.run(run_stress_test(**vars(args)))

    # Write results JSON for CI consumption
    snapshot_keys = (
        "timestamp",
        "allocated_bytes",
        "resident_bytes",
        "state_entries",
        "state_bytes",
        "requests_completed",
    )
    report = {
        "passed": result.passed,
        "failure_reason": result.failure_reason,
        "total_requests": result.total_requests,
        "total_errors": result.total_errors,
        "elapsed_secs": result.elapsed_secs,
        "snapshots": [
            {key: getattr(snap, key) for key in snapshot_keys}
            for snap in result.snapshots
        ],
    }
    print()
    print(json.dumps(report, indent=2))

    exit_code = 0 if result.passed else 1
    sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue