slim down to jemalloc + debug endpoint + stress tests only

This commit is contained in:
Adil Hafeez 2026-04-23 00:25:19 -07:00
parent 80a26a3289
commit 800222dc23
8 changed files with 5 additions and 567 deletions

View file

@ -477,14 +477,6 @@ properties:
connection_string:
type: string
description: Required when type is postgres. Supports environment variable substitution using $VAR or ${VAR} syntax.
ttl_seconds:
type: integer
minimum: 60
description: TTL in seconds for in-memory state entries. Only applies when type is memory. Default 1800 (30 min).
max_entries:
type: integer
minimum: 100
description: Maximum number of in-memory state entries. Only applies when type is memory. Default 10000.
additionalProperties: false
required:
- type

View file

@ -1,10 +1,8 @@
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use hyper::{Response, StatusCode};
use std::sync::Arc;
use super::full;
use crate::app_state::AppState;
#[derive(serde::Serialize)]
struct MemStats {
@ -30,7 +28,6 @@ pub async fn memstats() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper:
fn get_jemalloc_stats() -> MemStats {
use tikv_jemalloc_ctl::{epoch, stats};
// Advance the jemalloc stats epoch so numbers are fresh.
if let Err(e) = epoch::advance() {
return MemStats {
allocated_bytes: 0,
@ -54,43 +51,3 @@ fn get_jemalloc_stats() -> MemStats {
error: Some("jemalloc feature not enabled".to_string()),
}
}
/// JSON response body for `GET /debug/state_size`.
#[derive(serde::Serialize)]
struct StateSize {
    // Number of entries currently held in the conversation state store.
    entry_count: usize,
    // Approximate in-memory footprint of those entries, in bytes (heuristic).
    estimated_bytes: usize,
    // Populated only when stats could not be collected; omitted from the
    // JSON output when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
/// Returns the number of entries and estimated byte size in the conversation state store.
///
/// Handler for `GET /debug/state_size`. Always responds `200 OK` with a JSON
/// `StateSize` body; failures are reported in the payload's `error` field
/// rather than via an HTTP error status.
pub async fn state_size(
    state: Arc<AppState>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    let result = match &state.state_storage {
        Some(storage) => match storage.entry_stats().await {
            Ok((count, bytes)) => StateSize {
                entry_count: count,
                estimated_bytes: bytes,
                error: None,
            },
            // Backend failed: surface the error text in the payload.
            Err(e) => StateSize {
                entry_count: 0,
                estimated_bytes: 0,
                error: Some(format!("{e}")),
            },
        },
        // No storage was configured at startup; report that instead of stats.
        None => StateSize {
            entry_count: 0,
            estimated_bytes: 0,
            error: Some("no state_storage configured".to_string()),
        },
    };
    // Serializing a plain derive(Serialize) struct cannot fail, so unwrap
    // here encodes an invariant rather than swallowing a real error.
    let json = serde_json::to_string(&result).unwrap();
    Ok(Response::builder()
        .status(StatusCode::OK)
        .header("Content-Type", "application/json")
        .body(full(json))
        .unwrap())
}

View file

@ -358,15 +358,11 @@ async fn init_state_storage(
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
common::configuration::StateStorageType::Memory => {
let ttl = storage_config.ttl_seconds.unwrap_or(1800);
let max = storage_config.max_entries.unwrap_or(10_000);
info!(
storage_type = "memory",
ttl_seconds = ttl,
max_entries = max,
"initialized conversation state storage"
);
Arc::new(MemoryConversationalStorage::with_limits(ttl, max))
Arc::new(MemoryConversationalStorage::new())
}
common::configuration::StateStorageType::Postgres => {
let connection_string = storage_config
@ -520,7 +516,6 @@ async fn dispatch(
}
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => cors_preflight(),
(&Method::GET, "/debug/memstats") => debug::memstats().await,
(&Method::GET, "/debug/state_size") => debug::state_size(Arc::clone(&state)).await,
_ => {
debug!(method = %req.method(), path = %path, "no route found");
let mut not_found = Response::new(empty());

View file

@ -3,98 +3,21 @@ use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use tracing::{debug, warn};
const DEFAULT_TTL_SECS: u64 = 1800; // 30 minutes
const DEFAULT_MAX_ENTRIES: usize = 10_000;
const EVICTION_INTERVAL_SECS: u64 = 60;
/// In-memory storage backend for conversation state.
///
/// Entries are evicted when they exceed `ttl_secs` or when the store grows
/// beyond `max_entries` (oldest-first by `created_at`). A background task
/// runs every 60 s to sweep expired entries.
/// In-memory storage backend for conversation state
/// Uses a HashMap wrapped in Arc<RwLock<>> for thread-safe access
#[derive(Clone)]
pub struct MemoryConversationalStorage {
storage: Arc<RwLock<HashMap<String, OpenAIConversationState>>>,
max_entries: usize,
}
impl MemoryConversationalStorage {
pub fn new() -> Self {
Self::with_limits(DEFAULT_TTL_SECS, DEFAULT_MAX_ENTRIES)
}
pub fn with_limits(ttl_secs: u64, max_entries: usize) -> Self {
let storage = Arc::new(RwLock::new(HashMap::new()));
let bg_storage = Arc::clone(&storage);
tokio::spawn(async move {
Self::eviction_loop(bg_storage, ttl_secs, max_entries).await;
});
Self {
storage,
max_entries,
storage: Arc::new(RwLock::new(HashMap::new())),
}
}
    /// Background sweep that keeps the in-memory store bounded.
    ///
    /// Runs forever: every `EVICTION_INTERVAL_SECS` it removes entries whose
    /// `created_at` is at or past `ttl_secs`, then — if the map still exceeds
    /// `max_entries` — drops the oldest entries (by `created_at`) until the
    /// store is back under the cap.
    async fn eviction_loop(
        storage: Arc<RwLock<HashMap<String, OpenAIConversationState>>>,
        ttl_secs: u64,
        max_entries: usize,
    ) {
        let interval = std::time::Duration::from_secs(EVICTION_INTERVAL_SECS);
        loop {
            tokio::time::sleep(interval).await;
            let now = chrono::Utc::now().timestamp();
            let cutoff = now - ttl_secs as i64;
            // Hold the write lock across both phases so they operate on a
            // consistent view of the map.
            let mut map = storage.write().await;
            let before = map.len();
            // Phase 1: remove expired entries
            map.retain(|_, state| state.created_at > cutoff);
            // Phase 2: if still over capacity, drop oldest entries
            if map.len() > max_entries {
                // Snapshot (key, created_at) pairs so we can sort by age
                // without borrowing the map while removing from it.
                let mut entries: Vec<(String, i64)> =
                    map.iter().map(|(k, v)| (k.clone(), v.created_at)).collect();
                entries.sort_by_key(|(_, ts)| *ts);
                let to_remove = map.len() - max_entries;
                for (key, _) in entries.into_iter().take(to_remove) {
                    map.remove(&key);
                }
            }
            let evicted = before.saturating_sub(map.len());
            if evicted > 0 {
                info!(
                    evicted,
                    remaining = map.len(),
                    "memory state store eviction sweep"
                );
            }
        }
    }
    /// Rough byte-size estimate for a single conversation state entry.
    ///
    /// Counts the struct itself plus its owned string fields, and
    /// approximates each input item by the length of its JSON serialization
    /// (falling back to 64 bytes if serialization fails). This is an
    /// observability heuristic, not an exact accounting of heap usage.
    fn estimate_entry_bytes(state: &OpenAIConversationState) -> usize {
        let base = std::mem::size_of::<OpenAIConversationState>()
            + state.response_id.len()
            + state.model.len()
            + state.provider.len();
        let items: usize = state
            .input_items
            .iter()
            .map(|item| {
                // Rough estimate: serialize to JSON and use its length as a proxy
                serde_json::to_string(item).map(|s| s.len()).unwrap_or(64)
            })
            .sum();
        base + items
    }
}
impl Default for MemoryConversationalStorage {
@ -115,26 +38,6 @@ impl StateStorage for MemoryConversationalStorage {
);
storage.insert(response_id, state);
// Inline cap check so we don't wait for the background sweep
if storage.len() > self.max_entries {
let mut entries: Vec<(String, i64)> = storage
.iter()
.map(|(k, v)| (k.clone(), v.created_at))
.collect();
entries.sort_by_key(|(_, ts)| *ts);
let to_remove = storage.len() - self.max_entries;
for (key, _) in entries.into_iter().take(to_remove) {
storage.remove(&key);
}
info!(
evicted = to_remove,
remaining = storage.len(),
"memory state store cap eviction on put"
);
}
Ok(())
}
@ -177,13 +80,6 @@ impl StateStorage for MemoryConversationalStorage {
Err(StateStorageError::NotFound(response_id.to_string()))
}
}
async fn entry_stats(&self) -> Result<(usize, usize), StateStorageError> {
let storage = self.storage.read().await;
let count = storage.len();
let bytes: usize = storage.values().map(Self::estimate_entry_bytes).sum();
Ok((count, bytes))
}
}
#[cfg(test)]
@ -739,137 +635,4 @@ mod tests {
_ => panic!("Expected MessageContent::Items"),
}
}
// -----------------------------------------------------------------------
// Stress / eviction tests
// -----------------------------------------------------------------------
    /// Build a test state via `create_test_state`, then override its
    /// `created_at` timestamp so tests can simulate entries of a given age.
    fn create_test_state_with_ts(
        response_id: &str,
        num_messages: usize,
        created_at: i64,
    ) -> OpenAIConversationState {
        let mut state = create_test_state(response_id, num_messages);
        state.created_at = created_at;
        state
    }
    #[tokio::test]
    async fn test_max_entries_cap_evicts_oldest() {
        let max = 100;
        let storage = MemoryConversationalStorage::with_limits(3600, max);
        let now = chrono::Utc::now().timestamp();
        // Insert 50 entries over the cap; timestamps increase with i, so the
        // lowest indices are the oldest entries.
        for i in 0..(max + 50) {
            let state =
                create_test_state_with_ts(&format!("resp_{i}"), 2, now - (max as i64) + i as i64);
            storage.put(state).await.unwrap();
        }
        let (count, _) = storage.entry_stats().await.unwrap();
        assert_eq!(count, max, "store should be capped at max_entries");
        // The oldest entries should have been evicted
        assert!(
            storage.get("resp_0").await.is_err(),
            "oldest entry should be evicted"
        );
        // The newest entry should still be present
        let newest = format!("resp_{}", max + 49);
        assert!(
            storage.get(&newest).await.is_ok(),
            "newest entry should be present"
        );
    }
    #[tokio::test]
    async fn test_ttl_eviction_removes_expired() {
        let ttl_secs = 2;
        // Huge max_entries so only TTL-based eviction can remove anything.
        let storage = MemoryConversationalStorage::with_limits(ttl_secs, 100_000);
        let now = chrono::Utc::now().timestamp();
        // Insert entries that are already "old" (created_at in the past beyond TTL)
        for i in 0..20 {
            let state =
                create_test_state_with_ts(&format!("old_{i}"), 1, now - (ttl_secs as i64) - 10);
            storage.put(state).await.unwrap();
        }
        // Insert fresh entries
        for i in 0..10 {
            let state = create_test_state_with_ts(&format!("new_{i}"), 1, now);
            storage.put(state).await.unwrap();
        }
        let (count_before, _) = storage.entry_stats().await.unwrap();
        assert_eq!(count_before, 30);
        // Wait for the eviction sweep (interval is 60s in prod, but the old
        // entries are already past TTL so a manual trigger would clean them).
        // For unit testing we directly call the eviction logic instead of waiting.
        {
            // Mirrors phase 1 of `eviction_loop`: retain only entries newer
            // than the TTL cutoff.
            let bg_storage = storage.storage.clone();
            let cutoff = now - ttl_secs as i64;
            let mut map = bg_storage.write().await;
            map.retain(|_, state| state.created_at > cutoff);
        }
        let (count_after, _) = storage.entry_stats().await.unwrap();
        assert_eq!(
            count_after, 10,
            "expired entries should be evicted, only fresh ones remain"
        );
        // Verify fresh entries are still present
        for i in 0..10 {
            assert!(storage.get(&format!("new_{i}")).await.is_ok());
        }
    }
    #[tokio::test]
    async fn test_entry_stats_reports_reasonable_size() {
        let storage = MemoryConversationalStorage::with_limits(3600, 10_000);
        let now = chrono::Utc::now().timestamp();
        // 50 entries with 5 messages each — comfortably under both limits,
        // so nothing should be evicted before we read the stats.
        for i in 0..50 {
            let state = create_test_state_with_ts(&format!("resp_{i}"), 5, now);
            storage.put(state).await.unwrap();
        }
        let (count, bytes) = storage.entry_stats().await.unwrap();
        assert_eq!(count, 50);
        assert!(bytes > 0, "estimated bytes should be positive");
        // Only a loose lower bound: the byte estimate is a heuristic, so an
        // exact value would make this test brittle.
        assert!(
            bytes > 50 * 100,
            "50 entries with 5 messages each should be at least a few KB"
        );
    }
    #[tokio::test]
    async fn test_high_concurrency_with_cap() {
        let max = 500;
        let storage = MemoryConversationalStorage::with_limits(3600, max);
        let now = chrono::Utc::now().timestamp();
        let mut handles = vec![];
        // 1000 concurrent puts against a cap of 500 exercise the inline
        // cap-eviction path under write-lock contention.
        for i in 0..1000 {
            let s = storage.clone();
            let handle = tokio::spawn(async move {
                let state = create_test_state_with_ts(&format!("conc_{i}"), 3, now + i as i64);
                s.put(state).await.unwrap();
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.await.unwrap();
        }
        let (count, _) = storage.entry_stats().await.unwrap();
        // Whatever the interleaving, `put` must never let the store exceed
        // the configured cap.
        assert!(
            count <= max,
            "store should never exceed max_entries, got {count}"
        );
    }
}

View file

@ -75,12 +75,6 @@ pub trait StateStorage: Send + Sync {
/// Delete state for a response_id (optional, for cleanup)
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError>;
    /// Return (entry_count, estimated_bytes) for observability.
    /// Backends that cannot cheaply compute this may return (0, 0).
    ///
    /// The default implementation reports an empty store; backends that can
    /// inspect their contents should override it.
    async fn entry_stats(&self) -> Result<(usize, usize), StateStorageError> {
        Ok((0, 0))
    }
fn merge(
&self,
prev_state: &OpenAIConversationState,

View file

@ -109,8 +109,6 @@ pub fn init_tracer(tracing_config: Option<&Tracing>) -> &'static SdkTracerProvid
let provider = SdkTracerProvider::builder()
.with_batch_exporter(exporter)
.with_max_attributes_per_span(64)
.with_max_events_per_span(16)
.build();
global::set_tracer_provider(provider.clone());

View file

@ -123,12 +123,6 @@ pub struct StateStorageConfig {
#[serde(rename = "type")]
pub storage_type: StateStorageType,
pub connection_string: Option<String>,
/// TTL in seconds for in-memory state entries (default: 1800 = 30 min).
/// Only applies when type is `memory`.
pub ttl_seconds: Option<u64>,
/// Maximum number of in-memory state entries (default: 10000).
/// Only applies when type is `memory`.
pub max_entries: Option<usize>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]

View file

@ -1,255 +0,0 @@
#!/usr/bin/env python3
"""
Stress test for Plano routing service to detect memory leaks.
Sends sustained traffic to the routing endpoint and monitors memory
via the /debug/memstats and /debug/state_size endpoints.
Usage:
# Against a local Plano instance (docker or native)
python routing_stress.py --base-url http://localhost:12000
# Custom parameters
python routing_stress.py \
--base-url http://localhost:12000 \
--num-requests 5000 \
--concurrency 20 \
--poll-interval 5 \
--growth-threshold 3.0
Requirements:
pip install httpx
"""
from __future__ import annotations
import argparse
import asyncio
import json
import sys
import time
import uuid
from dataclasses import dataclass, field
import httpx
@dataclass
class MemSnapshot:
    """One point-in-time reading of process memory and state-store size."""

    # Wall-clock time (time.time()) when the snapshot was taken.
    timestamp: float
    # jemalloc "allocated_bytes" as reported by /debug/memstats.
    allocated_bytes: int
    # "resident_bytes" as reported by /debug/memstats.
    resident_bytes: int
    # "entry_count" as reported by /debug/state_size.
    state_entries: int
    # "estimated_bytes" as reported by /debug/state_size.
    state_bytes: int
    # Number of stress requests completed at snapshot time.
    requests_completed: int
@dataclass
class StressResult:
    """Aggregate outcome of one stress run, including all memory snapshots."""

    # Snapshots collected during the run (baseline first, final last).
    snapshots: list[MemSnapshot] = field(default_factory=list)
    # Requests that finished (successfully or not).
    total_requests: int = 0
    # Requests that raised or returned an HTTP status >= 400.
    total_errors: int = 0
    # Wall-clock duration of the request-sending phase, in seconds.
    elapsed_secs: float = 0.0
    # False when resident-memory growth exceeded the configured threshold.
    passed: bool = True
    # Human-readable explanation, populated only when passed is False.
    failure_reason: str = ""
def make_routing_body(unique: bool = True) -> dict:
    """Construct the smallest chat-completions payload the routing endpoint accepts.

    When ``unique`` is true the user message embeds a fresh UUID so every
    request body differs (defeats any response caching); otherwise the
    content is a fixed string.
    """
    suffix = uuid.uuid4() if unique else "static"
    message = {"role": "user", "content": f"test message {suffix}"}
    return {"model": "gpt-4o", "messages": [message]}
async def poll_debug_endpoints(
    client: httpx.AsyncClient,
    base_url: str,
    requests_completed: int,
) -> MemSnapshot | None:
    """Sample /debug/memstats and /debug/state_size into a MemSnapshot.

    Returns None (after printing a warning to stderr) if either endpoint
    cannot be reached or its body cannot be parsed as JSON.
    """
    try:
        mem_data = (await client.get(f"{base_url}/debug/memstats", timeout=5)).json()
        state_data = (await client.get(f"{base_url}/debug/state_size", timeout=5)).json()
        return MemSnapshot(
            timestamp=time.time(),
            allocated_bytes=mem_data.get("allocated_bytes", 0),
            resident_bytes=mem_data.get("resident_bytes", 0),
            state_entries=state_data.get("entry_count", 0),
            state_bytes=state_data.get("estimated_bytes", 0),
            requests_completed=requests_completed,
        )
    except Exception as e:
        # Polling is best-effort: a failed sample must not abort the run.
        print(f" [warn] failed to poll debug endpoints: {e}", file=sys.stderr)
        return None
async def send_requests(
    client: httpx.AsyncClient,
    url: str,
    count: int,
    semaphore: asyncio.Semaphore,
    counter: dict,
):
    """Issue `count` routing requests, never exceeding the semaphore's concurrency.

    Mutates `counter` in place: "completed" is bumped for every attempt,
    "errors" for exceptions or HTTP statuses >= 400.
    """
    remaining = count
    while remaining > 0:
        remaining -= 1
        async with semaphore:
            try:
                response = await client.post(
                    url, json=make_routing_body(unique=True), timeout=30
                )
                if response.status_code >= 400:
                    counter["errors"] += 1
            except Exception:
                # Transport-level failures count as errors, not crashes.
                counter["errors"] += 1
            finally:
                # Every attempt counts as completed, success or not.
                counter["completed"] += 1
async def run_stress_test(
    base_url: str,
    num_requests: int,
    concurrency: int,
    poll_interval: float,
    growth_threshold: float,
) -> StressResult:
    """Drive sustained traffic at the routing endpoint while sampling memory.

    Sends `num_requests` POSTs (at most `concurrency` in flight) and polls
    the debug endpoints every `poll_interval` seconds. The run fails when
    resident memory grows more than `growth_threshold`x over the baseline.
    """
    result = StressResult()
    routing_url = f"{base_url}/routing/v1/chat/completions"
    print(f"Stress test config:")
    print(f" base_url: {base_url}")
    print(f" routing_url: {routing_url}")
    print(f" num_requests: {num_requests}")
    print(f" concurrency: {concurrency}")
    print(f" poll_interval: {poll_interval}s")
    print(f" growth_threshold: {growth_threshold}x")
    print()
    async with httpx.AsyncClient() as client:
        # Take baseline snapshot
        baseline = await poll_debug_endpoints(client, base_url, 0)
        if baseline:
            result.snapshots.append(baseline)
            print(f"[baseline] allocated={baseline.allocated_bytes:,}B "
                  f"resident={baseline.resident_bytes:,}B "
                  f"state_entries={baseline.state_entries}")
        else:
            print("[warn] could not get baseline snapshot, continuing anyway")
        counter = {"completed": 0, "errors": 0}
        semaphore = asyncio.Semaphore(concurrency)
        start = time.time()
        # Launch request sender and poller concurrently
        sender = asyncio.create_task(
            send_requests(client, routing_url, num_requests, semaphore, counter)
        )
        # Poll memory while requests are in flight
        # NOTE(review): the loop sleeps before re-checking sender.done(), so
        # the run can overshoot by up to one poll_interval after the last
        # request completes.
        while not sender.done():
            await asyncio.sleep(poll_interval)
            snapshot = await poll_debug_endpoints(client, base_url, counter["completed"])
            if snapshot:
                result.snapshots.append(snapshot)
                print(
                    f" [{counter['completed']:>6}/{num_requests}] "
                    f"allocated={snapshot.allocated_bytes:,}B "
                    f"resident={snapshot.resident_bytes:,}B "
                    f"state_entries={snapshot.state_entries} "
                    f"state_bytes={snapshot.state_bytes:,}B"
                )
        # Re-raises any exception the sender task hit.
        await sender
        result.elapsed_secs = time.time() - start
        result.total_requests = counter["completed"]
        result.total_errors = counter["errors"]
        # Final snapshot
        final = await poll_debug_endpoints(client, base_url, counter["completed"])
        if final:
            result.snapshots.append(final)
    # Analyze results
    print()
    print(f"Completed {result.total_requests} requests in {result.elapsed_secs:.1f}s "
          f"({result.total_errors} errors)")
    if len(result.snapshots) >= 2:
        first = result.snapshots[0]
        last = result.snapshots[-1]
        # Growth is judged on resident memory; skipped when the baseline
        # poll failed (resident_bytes == 0) to avoid a divide-by-zero.
        if first.resident_bytes > 0:
            growth_ratio = last.resident_bytes / first.resident_bytes
            print(f"Memory growth: {first.resident_bytes:,}B -> {last.resident_bytes:,}B "
                  f"({growth_ratio:.2f}x)")
            if growth_ratio > growth_threshold:
                result.passed = False
                result.failure_reason = (
                    f"Memory grew {growth_ratio:.2f}x (threshold: {growth_threshold}x). "
                    f"Likely memory leak detected."
                )
                print(f"FAIL: {result.failure_reason}")
            else:
                print(f"PASS: Memory growth {growth_ratio:.2f}x is within {growth_threshold}x threshold")
        print(f"State store: {last.state_entries} entries, {last.state_bytes:,}B")
    else:
        print("[warn] not enough snapshots to analyze memory growth")
    return result
def main():
    """CLI entry point: parse flags, run the stress test, emit a JSON report.

    Exits 0 when the memory-growth check passed, 1 otherwise, so CI can gate
    on the process exit code; the JSON report is printed for log scraping.
    """
    parser = argparse.ArgumentParser(description="Plano routing service stress test")
    parser.add_argument("--base-url", default="http://localhost:12000",
                        help="Base URL of the Plano instance")
    parser.add_argument("--num-requests", type=int, default=2000,
                        help="Total number of requests to send")
    parser.add_argument("--concurrency", type=int, default=10,
                        help="Max concurrent requests")
    parser.add_argument("--poll-interval", type=float, default=5.0,
                        help="Seconds between memory polls")
    parser.add_argument("--growth-threshold", type=float, default=3.0,
                        help="Max allowed memory growth ratio (fail if exceeded)")
    args = parser.parse_args()

    result = asyncio.run(
        run_stress_test(
            base_url=args.base_url,
            num_requests=args.num_requests,
            concurrency=args.concurrency,
            poll_interval=args.poll_interval,
            growth_threshold=args.growth_threshold,
        )
    )

    def snapshot_as_dict(s):
        # Flatten one snapshot into plain JSON-serializable values.
        return {
            "timestamp": s.timestamp,
            "allocated_bytes": s.allocated_bytes,
            "resident_bytes": s.resident_bytes,
            "state_entries": s.state_entries,
            "state_bytes": s.state_bytes,
            "requests_completed": s.requests_completed,
        }

    # Write results JSON for CI consumption
    report = {
        "passed": result.passed,
        "failure_reason": result.failure_reason,
        "total_requests": result.total_requests,
        "total_errors": result.total_errors,
        "elapsed_secs": result.elapsed_secs,
        "snapshots": [snapshot_as_dict(s) for s in result.snapshots],
    }
    print()
    print(json.dumps(report, indent=2))
    sys.exit(0 if result.passed else 1)
if __name__ == "__main__":
main()