mirror of
https://github.com/katanemo/plano.git
synced 2026-06-11 15:05:14 +02:00
Spawn the local `claude` binary as a subprocess and expose it as an
Anthropic Messages-compatible provider. The bridge is hosted in brightstaff
on the address given by `CLAUDE_CLI_LISTEN_ADDR`, with session reuse, an idle
TTL, and a watchdog. The user-facing surface is `model_providers: [{ model: claude-cli/* }]`:
the Python CLI auto-fills name/provider_interface/base_url/access_key,
and the launcher (native and supervisord) enables the bridge listener
only when at least one claude-cli provider is present.
671 lines
25 KiB
Rust
/// Process-wide allocator override: when the `jemalloc` cargo feature is
/// enabled, every heap allocation in the binary goes through jemalloc instead
/// of the system allocator.
#[cfg(feature = "jemalloc")]
#[global_allocator]
static GLOBAL_ALLOCATOR: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
|
|
|
|
use brightstaff::app_state::AppState;
|
|
use brightstaff::handlers::agents::orchestrator::agent_chat;
|
|
use brightstaff::handlers::claude_cli::{
|
|
self, ClaudeCliConfig, SessionManager, SessionManagerConfig,
|
|
};
|
|
use brightstaff::handlers::debug;
|
|
use brightstaff::handlers::empty;
|
|
use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
|
use brightstaff::handlers::llm::llm_chat;
|
|
use brightstaff::handlers::models::list_models;
|
|
use brightstaff::handlers::routing_service::routing_decision;
|
|
use brightstaff::metrics as bs_metrics;
|
|
use brightstaff::metrics::labels as metric_labels;
|
|
use brightstaff::router::model_metrics::ModelMetricsService;
|
|
use brightstaff::router::orchestrator::OrchestratorService;
|
|
use brightstaff::session_cache::init_session_cache;
|
|
use brightstaff::state::memory::MemoryConversationalStorage;
|
|
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
|
use brightstaff::state::StateStorage;
|
|
use brightstaff::tracing::init_tracer;
|
|
use bytes::Bytes;
|
|
use common::configuration::{
|
|
Agent, Configuration, FilterPipeline, ListenerType, ResolvedFilterChain,
|
|
};
|
|
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
|
|
use common::llm_providers::LlmProviders;
|
|
use http_body_util::combinators::BoxBody;
|
|
use hyper::body::Incoming;
|
|
use hyper::header::HeaderValue;
|
|
use hyper::server::conn::http1;
|
|
use hyper::service::service_fn;
|
|
use hyper::{Method, Request, Response, StatusCode};
|
|
use hyper_util::rt::TokioIo;
|
|
use opentelemetry::global;
|
|
use opentelemetry::trace::FutureExt;
|
|
use opentelemetry_http::HeaderExtractor;
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
use std::{env, fs};
|
|
use tokio::net::TcpListener;
|
|
use tokio::sync::RwLock;
|
|
use tracing::{debug, info, warn};
|
|
|
|
/// Orchestrator model used when `overrides` sets neither `llm_routing_model`
/// nor `agent_orchestration_model`.
const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";

/// Provider name used for orchestration when no `model_providers` entry
/// declares the resolved orchestrator model.
const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator";

/// Default listen address of the main HTTP server; the `BIND_ADDRESS`
/// environment variable takes precedence when set.
const BIND_ADDRESS: &str = "0.0.0.0:9091";
|
|
|
|
/// Parse a version string like `v0.4.0`, `v0.3.0`, `0.2.0` into a `(major, minor, patch)` tuple.
///
/// Surrounding whitespace and leading `v`/`V` characters are ignored. Any SemVer
/// pre-release (`-rc1`) or build-metadata (`+abc`) suffix is dropped before the
/// numeric components are read, so `v0.4-rc1` yields `(0, 4, 0)` instead of having
/// its minor component zeroed by the suffix. Components beyond the third are
/// ignored (`1.2.3.4` yields `(1, 2, 3)`).
///
/// Missing parts default to 0. Non-numeric parts are treated as 0.
fn parse_semver(version: &str) -> (u32, u32, u32) {
    let v = version
        .trim()
        .trim_start_matches(|c: char| c == 'v' || c == 'V');
    // Keep only the `MAJOR.MINOR.PATCH` core; `split` always yields at least one item.
    let core = v.split(|c: char| c == '-' || c == '+').next().unwrap_or(v);
    let mut parts = core.split('.').map(|p| p.parse::<u32>().unwrap_or(0));
    let major = parts.next().unwrap_or(0);
    let minor = parts.next().unwrap_or(0);
    let patch = parts.next().unwrap_or(0);
    (major, minor, patch)
}
|
|
|
|
/// CORS pre-flight response for the models endpoint.
|
|
fn cors_preflight() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
|
let mut response = Response::new(empty());
|
|
*response.status_mut() = StatusCode::NO_CONTENT;
|
|
let h = response.headers_mut();
|
|
h.insert("Allow", HeaderValue::from_static("GET, OPTIONS"));
|
|
h.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
|
|
h.insert(
|
|
"Access-Control-Allow-Headers",
|
|
HeaderValue::from_static("Authorization, Content-Type"),
|
|
);
|
|
h.insert(
|
|
"Access-Control-Allow-Methods",
|
|
HeaderValue::from_static("GET, POST, OPTIONS"),
|
|
);
|
|
h.insert("Content-Type", HeaderValue::from_static("application/json"));
|
|
Ok(response)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Configuration loading
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Load and parse the YAML configuration file.
|
|
///
|
|
/// The path is read from `PLANO_CONFIG_PATH_RENDERED` (env) or falls back to
|
|
/// `./plano_config_rendered.yaml`.
|
|
fn load_config() -> Result<Configuration, Box<dyn std::error::Error + Send + Sync>> {
|
|
let path = env::var("PLANO_CONFIG_PATH_RENDERED")
|
|
.unwrap_or_else(|_| "./plano_config_rendered.yaml".to_string());
|
|
eprintln!("loading plano_config.yaml from {}", path);
|
|
|
|
let contents = fs::read_to_string(&path).map_err(|e| format!("failed to read {path}: {e}"))?;
|
|
|
|
let config: Configuration =
|
|
serde_yaml::from_str(&contents).map_err(|e| format!("failed to parse {path}: {e}"))?;
|
|
|
|
Ok(config)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Application state initialization
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Build the shared [`AppState`] from a parsed [`Configuration`].
|
|
async fn init_app_state(
|
|
config: &Configuration,
|
|
) -> Result<AppState, Box<dyn std::error::Error + Send + Sync>> {
|
|
let llm_provider_url =
|
|
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
|
|
|
// Combine agents and filters into a single list
|
|
let all_agents: Vec<Agent> = config
|
|
.agents
|
|
.as_deref()
|
|
.unwrap_or_default()
|
|
.iter()
|
|
.chain(config.filters.as_deref().unwrap_or_default())
|
|
.cloned()
|
|
.collect();
|
|
|
|
let global_agent_map: HashMap<String, Agent> = all_agents
|
|
.iter()
|
|
.map(|a| (a.id.clone(), a.clone()))
|
|
.collect();
|
|
|
|
let llm_providers = LlmProviders::try_from(config.model_providers.clone())
|
|
.map_err(|e| format!("failed to create LlmProviders: {e}"))?;
|
|
|
|
let model_listener_count = config
|
|
.listeners
|
|
.iter()
|
|
.filter(|l| l.listener_type == ListenerType::Model)
|
|
.count();
|
|
if model_listener_count > 1 {
|
|
return Err(format!(
|
|
"only one model listener is allowed, found {}",
|
|
model_listener_count
|
|
)
|
|
.into());
|
|
}
|
|
let model_listener = config
|
|
.listeners
|
|
.iter()
|
|
.find(|l| l.listener_type == ListenerType::Model);
|
|
let resolve_chain = |filter_ids: Option<Vec<String>>| -> Option<ResolvedFilterChain> {
|
|
filter_ids.map(|ids| {
|
|
let agents = ids
|
|
.iter()
|
|
.filter_map(|id| {
|
|
global_agent_map
|
|
.get(id)
|
|
.map(|a: &Agent| (id.clone(), a.clone()))
|
|
})
|
|
.collect();
|
|
ResolvedFilterChain {
|
|
filter_ids: ids,
|
|
agents,
|
|
}
|
|
})
|
|
};
|
|
let filter_pipeline = Arc::new(FilterPipeline {
|
|
input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())),
|
|
output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())),
|
|
});
|
|
|
|
let overrides = config.overrides.clone().unwrap_or_default();
|
|
|
|
let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds);
|
|
let session_cache = init_session_cache(config).await?;
|
|
|
|
// Validate that top-level routing_preferences requires v0.4.0+.
|
|
let config_version = parse_semver(&config.version);
|
|
let is_v040_plus = config_version >= (0, 4, 0);
|
|
|
|
if !is_v040_plus && config.routing_preferences.is_some() {
|
|
return Err(
|
|
"top-level routing_preferences requires version v0.4.0 or above. \
|
|
Update the version field or remove routing_preferences."
|
|
.into(),
|
|
);
|
|
}
|
|
|
|
// Validate that all models referenced in top-level routing_preferences exist in model_providers.
|
|
// The CLI renders model_providers with `name` = "openai/gpt-4o" and `model` = "gpt-4o",
|
|
// so we accept a match against either field.
|
|
if let Some(ref route_prefs) = config.routing_preferences {
|
|
let provider_model_names: std::collections::HashSet<&str> = config
|
|
.model_providers
|
|
.iter()
|
|
.flat_map(|p| std::iter::once(p.name.as_str()).chain(p.model.as_deref()))
|
|
.collect();
|
|
for pref in route_prefs {
|
|
for model in &pref.models {
|
|
if !provider_model_names.contains(model.as_str()) {
|
|
return Err(format!(
|
|
"routing_preferences route '{}' references model '{}' \
|
|
which is not declared in model_providers",
|
|
pref.name, model
|
|
)
|
|
.into());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Validate and initialize ModelMetricsService if model_metrics_sources is configured.
|
|
let metrics_service: Option<Arc<ModelMetricsService>> = if let Some(ref sources) =
|
|
config.model_metrics_sources
|
|
{
|
|
use common::configuration::MetricsSource;
|
|
let cost_count = sources
|
|
.iter()
|
|
.filter(|s| matches!(s, MetricsSource::Cost(_)))
|
|
.count();
|
|
let latency_count = sources
|
|
.iter()
|
|
.filter(|s| matches!(s, MetricsSource::Latency(_)))
|
|
.count();
|
|
if cost_count > 1 {
|
|
return Err("model_metrics_sources: only one cost metrics source is allowed".into());
|
|
}
|
|
if latency_count > 1 {
|
|
return Err("model_metrics_sources: only one latency metrics source is allowed".into());
|
|
}
|
|
let svc = ModelMetricsService::new(sources, reqwest::Client::new()).await;
|
|
Some(Arc::new(svc))
|
|
} else {
|
|
None
|
|
};
|
|
|
|
// Validate that selection_policy.prefer is compatible with the configured metric sources.
|
|
if let Some(ref prefs) = config.routing_preferences {
|
|
use common::configuration::{MetricsSource, SelectionPreference};
|
|
|
|
let has_cost_source = config
|
|
.model_metrics_sources
|
|
.as_deref()
|
|
.unwrap_or_default()
|
|
.iter()
|
|
.any(|s| matches!(s, MetricsSource::Cost(_)));
|
|
let has_latency_source = config
|
|
.model_metrics_sources
|
|
.as_deref()
|
|
.unwrap_or_default()
|
|
.iter()
|
|
.any(|s| matches!(s, MetricsSource::Latency(_)));
|
|
|
|
for pref in prefs {
|
|
if pref.selection_policy.prefer == SelectionPreference::Cheapest && !has_cost_source {
|
|
return Err(format!(
|
|
"routing_preferences route '{}' uses prefer: cheapest but no cost metrics source is configured — \
|
|
add a cost metrics source to model_metrics_sources",
|
|
pref.name
|
|
)
|
|
.into());
|
|
}
|
|
if pref.selection_policy.prefer == SelectionPreference::Fastest && !has_latency_source {
|
|
return Err(format!(
|
|
"routing_preferences route '{}' uses prefer: fastest but no latency metrics source is configured — \
|
|
add a latency metrics source to model_metrics_sources",
|
|
pref.name
|
|
)
|
|
.into());
|
|
}
|
|
}
|
|
}
|
|
|
|
// Warn about models in routing_preferences that have no matching pricing/latency data.
|
|
if let (Some(ref prefs), Some(ref svc)) = (&config.routing_preferences, &metrics_service) {
|
|
let cost_data = svc.cost_snapshot().await;
|
|
let latency_data = svc.latency_snapshot().await;
|
|
for pref in prefs {
|
|
use common::configuration::SelectionPreference;
|
|
for model in &pref.models {
|
|
let missing = match pref.selection_policy.prefer {
|
|
SelectionPreference::Cheapest => !cost_data.contains_key(model.as_str()),
|
|
SelectionPreference::Fastest => !latency_data.contains_key(model.as_str()),
|
|
_ => false,
|
|
};
|
|
if missing {
|
|
warn!(
|
|
model = %model,
|
|
route = %pref.name,
|
|
"model has no metric data — will be ranked last"
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let session_tenant_header = config
|
|
.routing
|
|
.as_ref()
|
|
.and_then(|r| r.session_cache.as_ref())
|
|
.and_then(|c| c.tenant_header.clone());
|
|
|
|
// Resolve model name: prefer llm_routing_model override, then agent_orchestration_model, then default.
|
|
let orchestrator_model_name: String = overrides
|
|
.llm_routing_model
|
|
.as_deref()
|
|
.or(overrides.agent_orchestration_model.as_deref())
|
|
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
|
|
.unwrap_or(DEFAULT_ORCHESTRATOR_MODEL_NAME)
|
|
.to_string();
|
|
|
|
let orchestrator_llm_provider: String = config
|
|
.model_providers
|
|
.iter()
|
|
.find(|p| p.model.as_deref() == Some(orchestrator_model_name.as_str()))
|
|
.map(|p| p.name.clone())
|
|
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string());
|
|
|
|
let orchestrator_max_tokens = overrides
|
|
.orchestrator_model_context_length
|
|
.unwrap_or(brightstaff::router::orchestrator_model_v1::MAX_TOKEN_LEN);
|
|
|
|
let orchestrator_service = Arc::new(OrchestratorService::with_routing(
|
|
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
|
orchestrator_model_name,
|
|
orchestrator_llm_provider,
|
|
config.routing_preferences.clone(),
|
|
metrics_service,
|
|
session_ttl_seconds,
|
|
session_cache,
|
|
session_tenant_header,
|
|
orchestrator_max_tokens,
|
|
));
|
|
|
|
let state_storage = init_state_storage(config).await?;
|
|
|
|
let span_attributes = config
|
|
.tracing
|
|
.as_ref()
|
|
.and_then(|tracing| tracing.span_attributes.clone());
|
|
|
|
let signals_enabled = !overrides.disable_signals.unwrap_or(false);
|
|
|
|
Ok(AppState {
|
|
orchestrator_service,
|
|
model_aliases: config.model_aliases.clone(),
|
|
llm_providers: Arc::new(RwLock::new(llm_providers)),
|
|
agents_list: Some(all_agents),
|
|
listeners: config.listeners.clone(),
|
|
state_storage,
|
|
llm_provider_url,
|
|
span_attributes,
|
|
http_client: reqwest::Client::new(),
|
|
filter_pipeline,
|
|
signals_enabled,
|
|
})
|
|
}
|
|
|
|
/// Initialize the conversation state storage backend (if configured).
|
|
async fn init_state_storage(
|
|
config: &Configuration,
|
|
) -> Result<Option<Arc<dyn StateStorage>>, Box<dyn std::error::Error + Send + Sync>> {
|
|
let Some(storage_config) = &config.state_storage else {
|
|
info!("no state_storage configured, conversation state management disabled");
|
|
return Ok(None);
|
|
};
|
|
|
|
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
|
|
common::configuration::StateStorageType::Memory => {
|
|
info!(
|
|
storage_type = "memory",
|
|
"initialized conversation state storage"
|
|
);
|
|
Arc::new(MemoryConversationalStorage::new())
|
|
}
|
|
common::configuration::StateStorageType::Postgres => {
|
|
let connection_string = storage_config
|
|
.connection_string
|
|
.as_ref()
|
|
.ok_or("connection_string is required for postgres state_storage")?;
|
|
|
|
debug!(connection_string = %connection_string, "postgres connection");
|
|
info!(
|
|
storage_type = "postgres",
|
|
"initializing conversation state storage"
|
|
);
|
|
|
|
Arc::new(
|
|
PostgreSQLConversationStorage::new(connection_string.clone())
|
|
.await
|
|
.map_err(|e| format!("failed to initialize Postgres state storage: {e}"))?,
|
|
)
|
|
}
|
|
};
|
|
|
|
Ok(Some(storage))
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Request routing
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Normalized method label — limited set so we never emit a free-form string.
|
|
fn method_label(method: &Method) -> &'static str {
|
|
match *method {
|
|
Method::GET => "GET",
|
|
Method::POST => "POST",
|
|
Method::PUT => "PUT",
|
|
Method::DELETE => "DELETE",
|
|
Method::PATCH => "PATCH",
|
|
Method::HEAD => "HEAD",
|
|
Method::OPTIONS => "OPTIONS",
|
|
_ => "OTHER",
|
|
}
|
|
}
|
|
|
|
/// Compute the fixed `handler` metric label from the request's path+method.
|
|
/// Returning `None` for fall-through means `route()` will hand the request to
|
|
/// the catch-all 404 branch.
|
|
fn handler_label_for(method: &Method, path: &str) -> &'static str {
|
|
if let Some(stripped) = path.strip_prefix("/agents") {
|
|
if matches!(
|
|
stripped,
|
|
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
|
) {
|
|
return metric_labels::HANDLER_AGENT_CHAT;
|
|
}
|
|
}
|
|
if let Some(stripped) = path.strip_prefix("/routing") {
|
|
if matches!(
|
|
stripped,
|
|
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
|
) {
|
|
return metric_labels::HANDLER_ROUTING_DECISION;
|
|
}
|
|
}
|
|
match (method, path) {
|
|
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
|
metric_labels::HANDLER_LLM_CHAT
|
|
}
|
|
(&Method::POST, "/function_calling") => metric_labels::HANDLER_FUNCTION_CALLING,
|
|
(&Method::GET, "/v1/models" | "/agents/v1/models") => metric_labels::HANDLER_LIST_MODELS,
|
|
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => {
|
|
metric_labels::HANDLER_CORS_PREFLIGHT
|
|
}
|
|
_ => metric_labels::HANDLER_NOT_FOUND,
|
|
}
|
|
}
|
|
|
|
/// Route an incoming HTTP request to the appropriate handler.
|
|
async fn route(
|
|
req: Request<Incoming>,
|
|
state: Arc<AppState>,
|
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
|
let handler = handler_label_for(req.method(), req.uri().path());
|
|
let method = method_label(req.method());
|
|
let started = std::time::Instant::now();
|
|
let _in_flight = bs_metrics::InFlightGuard::new(handler);
|
|
|
|
let result = dispatch(req, state).await;
|
|
|
|
let status = match &result {
|
|
Ok(resp) => resp.status().as_u16(),
|
|
// hyper::Error here means the body couldn't be produced; conventionally 500.
|
|
Err(_) => 500,
|
|
};
|
|
bs_metrics::record_http(handler, method, status, started);
|
|
result
|
|
}
|
|
|
|
/// Inner dispatcher split out so `route()` can wrap it with metrics without
|
|
/// duplicating the match tree.
|
|
async fn dispatch(
|
|
req: Request<Incoming>,
|
|
state: Arc<AppState>,
|
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
|
let parent_cx = global::get_text_map_propagator(|p| p.extract(&HeaderExtractor(req.headers())));
|
|
let path = req.uri().path().to_string();
|
|
|
|
// --- Agent routes (/agents/...) ---
|
|
if let Some(stripped) = path.strip_prefix("/agents") {
|
|
if matches!(
|
|
stripped,
|
|
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
|
) {
|
|
return agent_chat(req, Arc::clone(&state))
|
|
.with_context(parent_cx)
|
|
.await;
|
|
}
|
|
}
|
|
|
|
// --- Routing decision routes (/routing/...) ---
|
|
if let Some(stripped) = path.strip_prefix("/routing") {
|
|
let stripped = stripped.to_string();
|
|
if matches!(
|
|
stripped.as_str(),
|
|
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
|
) {
|
|
return routing_decision(
|
|
req,
|
|
Arc::clone(&state.orchestrator_service),
|
|
stripped,
|
|
&state.span_attributes,
|
|
)
|
|
.with_context(parent_cx)
|
|
.await;
|
|
}
|
|
}
|
|
|
|
// --- Standard routes ---
|
|
match (req.method(), path.as_str()) {
|
|
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
|
llm_chat(req, Arc::clone(&state))
|
|
.with_context(parent_cx)
|
|
.await
|
|
}
|
|
(&Method::POST, "/function_calling") => {
|
|
let url = format!("{}/v1/chat/completions", state.llm_provider_url);
|
|
function_calling_chat_handler(req, url)
|
|
.with_context(parent_cx)
|
|
.await
|
|
}
|
|
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
|
|
Ok(list_models(Arc::clone(&state.llm_providers)).await)
|
|
}
|
|
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => cors_preflight(),
|
|
(&Method::GET, "/debug/memstats") => debug::memstats().await,
|
|
_ => {
|
|
debug!(method = %req.method(), path = %path, "no route found");
|
|
let mut not_found = Response::new(empty());
|
|
*not_found.status_mut() = StatusCode::NOT_FOUND;
|
|
Ok(not_found)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Server loop
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Accept connections and spawn a task per connection.
|
|
///
|
|
/// Listens for `SIGINT` / `ctrl-c` and shuts down gracefully, allowing
|
|
/// in-flight connections to finish.
|
|
async fn run_server(state: Arc<AppState>) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
|
|
let listener = TcpListener::bind(&bind_address).await?;
|
|
info!(address = %bind_address, "server listening");
|
|
|
|
let shutdown = tokio::signal::ctrl_c();
|
|
tokio::pin!(shutdown);
|
|
|
|
loop {
|
|
tokio::select! {
|
|
result = listener.accept() => {
|
|
let (stream, _) = result?;
|
|
let peer_addr = stream.peer_addr()?;
|
|
let io = TokioIo::new(stream);
|
|
let state = Arc::clone(&state);
|
|
|
|
tokio::task::spawn(async move {
|
|
debug!(peer = ?peer_addr, "accepted connection");
|
|
|
|
let service = service_fn(move |req| {
|
|
let state = Arc::clone(&state);
|
|
async move { route(req, state).await }
|
|
});
|
|
|
|
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
|
|
warn!(error = ?err, "error serving connection");
|
|
}
|
|
});
|
|
}
|
|
_ = &mut shutdown => {
|
|
info!("received shutdown signal, stopping server");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// claude-cli bridge wiring
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Build the [`SessionManagerConfig`] from environment variables. Returns
|
|
/// `None` when `CLAUDE_CLI_LISTEN_ADDR` is unset, signaling that the bridge
|
|
/// should not start at all (zero-cost when no claude-cli provider exists).
|
|
fn claude_cli_config_from_env() -> Option<(std::net::SocketAddr, SessionManagerConfig)> {
|
|
let addr_str = env::var("CLAUDE_CLI_LISTEN_ADDR").ok()?;
|
|
let addr: std::net::SocketAddr = match addr_str.parse() {
|
|
Ok(a) => a,
|
|
Err(err) => {
|
|
warn!(
|
|
value = %addr_str,
|
|
error = %err,
|
|
"invalid CLAUDE_CLI_LISTEN_ADDR — claude-cli bridge disabled"
|
|
);
|
|
return None;
|
|
}
|
|
};
|
|
let binary = env::var("CLAUDE_CLI_BIN").unwrap_or_else(|_| "claude".to_string());
|
|
let permission_mode =
|
|
env::var("CLAUDE_CLI_PERMISSION_MODE").unwrap_or_else(|_| "bypassPermissions".to_string());
|
|
let session_ttl = env::var("CLAUDE_CLI_SESSION_TTL_SECS")
|
|
.ok()
|
|
.and_then(|s| s.parse::<u64>().ok())
|
|
.map(Duration::from_secs)
|
|
.unwrap_or_else(|| Duration::from_secs(600));
|
|
let watchdog = env::var("CLAUDE_CLI_WATCHDOG_SECS")
|
|
.ok()
|
|
.and_then(|s| s.parse::<u64>().ok())
|
|
.map(Duration::from_secs)
|
|
.unwrap_or_else(|| Duration::from_secs(120));
|
|
let max_sessions = env::var("CLAUDE_CLI_MAX_SESSIONS")
|
|
.ok()
|
|
.and_then(|s| s.parse::<usize>().ok())
|
|
.unwrap_or(claude_cli::session::DEFAULT_MAX_SESSIONS);
|
|
Some((
|
|
addr,
|
|
SessionManagerConfig {
|
|
max_sessions,
|
|
process: ClaudeCliConfig {
|
|
binary,
|
|
permission_mode,
|
|
session_ttl,
|
|
watchdog,
|
|
},
|
|
},
|
|
))
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Entry point
|
|
// ---------------------------------------------------------------------------
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
let config = load_config()?;
|
|
let _tracer_provider = init_tracer(config.tracing.as_ref());
|
|
bs_metrics::init();
|
|
info!("loaded plano_config.yaml");
|
|
let state = Arc::new(init_app_state(&config).await?);
|
|
|
|
// Optional claude-cli bridge listener. Started iff CLAUDE_CLI_LISTEN_ADDR
|
|
// is set in the environment (the Python CLI sets this when it detects a
|
|
// `model: claude-cli/*` provider entry).
|
|
let bridge_handle = if let Some((addr, cfg)) = claude_cli_config_from_env() {
|
|
let manager = SessionManager::new(cfg);
|
|
let shutdown = async {
|
|
let _ = tokio::signal::ctrl_c().await;
|
|
};
|
|
Some(tokio::spawn(async move {
|
|
if let Err(err) = claude_cli::run_listener(addr, manager, shutdown).await {
|
|
warn!(error = ?err, "claude-cli bridge listener exited with error");
|
|
}
|
|
}))
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let result = run_server(state).await;
|
|
|
|
if let Some(handle) = bridge_handle {
|
|
// Ctrl-C already triggered the bridge's own shutdown; join briefly to
|
|
// give in-flight session drains a chance to finish.
|
|
let _ = tokio::time::timeout(Duration::from_secs(5), handle).await;
|
|
}
|
|
|
|
result
|
|
}
|