mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
refactor: pass AppState to handlers, add shared HTTP client, fix remaining unwraps
- Pass Arc<AppState> directly to llm_chat and agent_chat instead of destructuring into individual parameters
- Add shared reqwest::Client to AppState for connection pooling on upstream LLM requests
- Fix unwrap panics in pipeline.rs: get_new_session_id now returns Result, invoke_agent to_bytes properly handled
- Fix unwrap panics in orchestrator.rs: strip_prefix and pop
- Fix unwrap panics in response.rs: SSE parsing no longer panics
- Fix unwrap panics in router services: serialization errors propagated
- Convert old-style string-format debug log in state/mod.rs to structured tracing fields

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
6fcecb60c3
commit
66e55e1621
9 changed files with 87 additions and 105 deletions
|
|
@ -22,4 +22,6 @@ pub struct AppState {
|
|||
pub listeners: Arc<RwLock<Vec<Listener>>>,
|
||||
pub state_storage: Option<Arc<dyn StateStorage>>,
|
||||
pub llm_provider_url: String,
|
||||
/// Shared HTTP client for upstream LLM requests (connection pooling / keep-alive).
|
||||
pub http_client: reqwest::Client,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,10 +15,10 @@ use tracing::{debug, info, info_span, warn, Instrument};
|
|||
|
||||
use super::pipeline::{PipelineError, PipelineProcessor};
|
||||
use super::selector::{AgentSelectionError, AgentSelector};
|
||||
use crate::app_state::AppState;
|
||||
use crate::handlers::errors::build_error_chain_response;
|
||||
use crate::handlers::request::extract_request_id;
|
||||
use crate::handlers::response::ResponseHandler;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::tracing::{operation_component, set_service_name};
|
||||
|
||||
/// Main errors for agent chat completions
|
||||
|
|
@ -38,9 +38,7 @@ pub enum AgentFilterChainError {
|
|||
|
||||
pub async fn agent_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
state: Arc<AppState>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_id = extract_request_id(&request);
|
||||
|
||||
|
|
@ -58,15 +56,7 @@ pub async fn agent_chat(
|
|||
// Set service name for orchestrator operations
|
||||
set_service_name(operation_component::ORCHESTRATOR);
|
||||
|
||||
match handle_agent_chat_inner(
|
||||
request,
|
||||
orchestrator_service,
|
||||
agents_list,
|
||||
listeners,
|
||||
request_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
match handle_agent_chat_inner(request, state, request_id).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
// Check if this is a client error from the pipeline that should be cascaded
|
||||
|
|
@ -112,13 +102,11 @@ pub async fn agent_chat(
|
|||
|
||||
async fn handle_agent_chat_inner(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
state: Arc<AppState>,
|
||||
request_id: String,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||
// Initialize services
|
||||
let agent_selector = AgentSelector::new(orchestrator_service);
|
||||
let agent_selector = AgentSelector::new(Arc::clone(&state.orchestrator_service));
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let response_handler = ResponseHandler::new();
|
||||
|
||||
|
|
@ -130,7 +118,7 @@ async fn handle_agent_chat_inner(
|
|||
|
||||
// Find the appropriate listener
|
||||
let listener: common::configuration::Listener = {
|
||||
let listeners = listeners.read().await;
|
||||
let listeners = state.listeners.read().await;
|
||||
agent_selector
|
||||
.find_listener(listener_name, &listeners)
|
||||
.await?
|
||||
|
|
@ -143,12 +131,10 @@ async fn handle_agent_chat_inner(
|
|||
info!(listener = %listener.name, "handling request");
|
||||
|
||||
// Parse request body
|
||||
let request_path = request
|
||||
.uri()
|
||||
.path()
|
||||
.to_string()
|
||||
let full_path = request.uri().path().to_string();
|
||||
let request_path = full_path
|
||||
.strip_prefix("/agents")
|
||||
.unwrap()
|
||||
.unwrap_or(&full_path)
|
||||
.to_string();
|
||||
|
||||
let request_headers = {
|
||||
|
|
@ -201,7 +187,7 @@ async fn handle_agent_chat_inner(
|
|||
|
||||
// Create agent map for pipeline processing and agent selection
|
||||
let agent_map = {
|
||||
let agents = agents_list.read().await;
|
||||
let agents = state.agents_list.read().await;
|
||||
let agents = agents.as_ref().ok_or_else(|| {
|
||||
AgentFilterChainError::RequestParsing(serde_json::Error::custom("No agents configured"))
|
||||
})?;
|
||||
|
|
@ -340,7 +326,10 @@ async fn handle_agent_chat_inner(
|
|||
);
|
||||
|
||||
// remove last message and add new one at the end
|
||||
let last_message = current_messages.pop().unwrap();
|
||||
let Some(last_message) = current_messages.pop() else {
|
||||
warn!(agent = %agent_name, "no messages in conversation history");
|
||||
break;
|
||||
};
|
||||
|
||||
// Create a new message with the agent's response as assistant message
|
||||
// and add it to the conversation history
|
||||
|
|
|
|||
|
|
@ -310,7 +310,7 @@ impl PipelineProcessor {
|
|||
let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) {
|
||||
session_id.clone()
|
||||
} else {
|
||||
let session_id = self.get_new_session_id(&agent.id, request_headers).await;
|
||||
let session_id = self.get_new_session_id(&agent.id, request_headers).await?;
|
||||
self.agent_id_session_map
|
||||
.insert(agent.id.clone(), session_id.clone());
|
||||
session_id
|
||||
|
|
@ -464,18 +464,19 @@ impl PipelineProcessor {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_new_session_id(&self, agent_id: &str, request_headers: &HeaderMap) -> String {
|
||||
async fn get_new_session_id(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<String, PipelineError> {
|
||||
info!("initializing MCP session for agent {}", agent_id);
|
||||
|
||||
let initialize_request = self.build_initialize_request();
|
||||
let headers = self
|
||||
.build_mcp_headers(request_headers, agent_id, None)
|
||||
.expect("Failed to build headers for initialization");
|
||||
let headers = self.build_mcp_headers(request_headers, agent_id, None)?;
|
||||
|
||||
let response = self
|
||||
.send_mcp_request(&initialize_request, &headers, agent_id)
|
||||
.await
|
||||
.expect("Failed to initialize MCP session");
|
||||
.await?;
|
||||
|
||||
info!("initialize response status: {}", response.status());
|
||||
|
||||
|
|
@ -483,8 +484,13 @@ impl PipelineProcessor {
|
|||
.headers()
|
||||
.get("mcp-session-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.expect("No mcp-session-id in response")
|
||||
.to_string();
|
||||
.map(|s| s.to_string())
|
||||
.ok_or_else(|| {
|
||||
PipelineError::NoContentInResponse(format!(
|
||||
"No mcp-session-id header in initialize response from agent {}",
|
||||
agent_id
|
||||
))
|
||||
})?;
|
||||
|
||||
info!(
|
||||
"created new MCP session for agent {}: {}",
|
||||
|
|
@ -493,10 +499,9 @@ impl PipelineProcessor {
|
|||
|
||||
// Send initialized notification
|
||||
self.send_initialized_notification(agent_id, &session_id, &headers)
|
||||
.await
|
||||
.expect("Failed to send initialized notification");
|
||||
.await?;
|
||||
|
||||
session_id
|
||||
Ok(session_id)
|
||||
}
|
||||
|
||||
/// Execute a HTTP-based filter agent
|
||||
|
|
@ -620,8 +625,8 @@ impl PipelineProcessor {
|
|||
|
||||
let request_url = "/v1/chat/completions";
|
||||
|
||||
let request_body = ProviderRequestType::to_bytes(&original_request).unwrap();
|
||||
// let request_body = serde_json::to_string(&request)?;
|
||||
let request_body = ProviderRequestType::to_bytes(&original_request)
|
||||
.map_err(|e| PipelineError::NoContentInResponse(e.to_string()))?;
|
||||
debug!("sending request to terminal agent {}", terminal_agent.id);
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
|
|
|
|||
|
|
@ -20,11 +20,11 @@ use tracing::{debug, info, info_span, warn, Instrument};
|
|||
|
||||
mod router;
|
||||
|
||||
use crate::app_state::AppState;
|
||||
use crate::handlers::request::extract_request_id;
|
||||
use crate::handlers::utils::{
|
||||
create_streaming_response, truncate_message, ObservableStreamProcessor,
|
||||
};
|
||||
use crate::router::llm::RouterService;
|
||||
use crate::state::response_state_processor::ResponsesStateProcessor;
|
||||
use crate::state::{
|
||||
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
|
||||
|
|
@ -40,11 +40,7 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
|||
|
||||
pub async fn llm_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
state: Arc<AppState>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_path = request.uri().path().to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
|
|
@ -64,29 +60,14 @@ pub async fn llm_chat(
|
|||
);
|
||||
|
||||
// Execute the rest of the handler inside the span
|
||||
llm_chat_inner(
|
||||
request,
|
||||
router_service,
|
||||
full_qualified_llm_provider_url,
|
||||
model_aliases,
|
||||
llm_providers,
|
||||
state_storage,
|
||||
request_id,
|
||||
request_path,
|
||||
request_headers,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
llm_chat_inner(request, state, request_id, request_path, request_headers)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn llm_chat_inner(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
state: Arc<AppState>,
|
||||
request_id: String,
|
||||
request_path: String,
|
||||
mut request_headers: hyper::HeaderMap,
|
||||
|
|
@ -96,14 +77,20 @@ async fn llm_chat_inner(
|
|||
|
||||
let traceparent = extract_or_generate_traceparent(&request_headers);
|
||||
|
||||
let full_qualified_llm_provider_url = format!("{}{}", state.llm_provider_url, request_path);
|
||||
|
||||
// --- Phase 1: Parse and validate the incoming request ---
|
||||
let parsed =
|
||||
match parse_and_validate_request(request, &request_path, &model_aliases, &llm_providers)
|
||||
.await
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(response) => return Ok(response),
|
||||
};
|
||||
let parsed = match parse_and_validate_request(
|
||||
request,
|
||||
&request_path,
|
||||
&state.model_aliases,
|
||||
&state.llm_providers,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(response) => return Ok(response),
|
||||
};
|
||||
|
||||
let PreparedRequest {
|
||||
mut client_request,
|
||||
|
|
@ -139,8 +126,8 @@ async fn llm_chat_inner(
|
|||
let state_ctx = match resolve_conversation_state(
|
||||
&mut client_request,
|
||||
is_responses_api_client,
|
||||
&state_storage,
|
||||
&llm_providers,
|
||||
&state.state_storage,
|
||||
&state.llm_providers,
|
||||
&alias_resolved_model,
|
||||
&request_path,
|
||||
is_streaming_request,
|
||||
|
|
@ -177,7 +164,7 @@ async fn llm_chat_inner(
|
|||
let routing_result = match async {
|
||||
set_service_name(operation_component::ROUTING);
|
||||
router_chat_get_upstream_model(
|
||||
router_service,
|
||||
Arc::clone(&state.router_service),
|
||||
client_request,
|
||||
&traceparent,
|
||||
&request_path,
|
||||
|
|
@ -207,6 +194,7 @@ async fn llm_chat_inner(
|
|||
|
||||
// --- Phase 4: Forward to upstream and stream back ---
|
||||
send_upstream(
|
||||
&state.http_client,
|
||||
&full_qualified_llm_provider_url,
|
||||
&mut request_headers,
|
||||
client_request_bytes_for_upstream,
|
||||
|
|
@ -218,7 +206,7 @@ async fn llm_chat_inner(
|
|||
is_streaming_request,
|
||||
messages_for_signals,
|
||||
state_ctx,
|
||||
state_storage,
|
||||
state.state_storage.clone(),
|
||||
request_id,
|
||||
)
|
||||
.await
|
||||
|
|
@ -458,6 +446,7 @@ async fn resolve_conversation_state(
|
|||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn send_upstream(
|
||||
http_client: &reqwest::Client,
|
||||
upstream_url: &str,
|
||||
request_headers: &mut hyper::HeaderMap,
|
||||
body: bytes::Bytes,
|
||||
|
|
@ -509,7 +498,7 @@ async fn send_upstream(
|
|||
|
||||
let request_start_time = std::time::Instant::now();
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
let llm_response = match http_client
|
||||
.post(upstream_url)
|
||||
.headers(request_headers.clone())
|
||||
.body(body)
|
||||
|
|
|
|||
|
|
@ -112,7 +112,9 @@ impl ResponseHandler {
|
|||
let upstream_api =
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).unwrap();
|
||||
let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).map_err(|e| {
|
||||
ResponseError::StreamError(format!("Failed to parse SSE stream: {}", e))
|
||||
})?;
|
||||
let mut accumulated_text = String::new();
|
||||
|
||||
for sse_event in sse_iter {
|
||||
|
|
@ -122,7 +124,13 @@ impl ResponseHandler {
|
|||
}
|
||||
|
||||
let transformed_event =
|
||||
SseEvent::try_from((sse_event, &client_api, &upstream_api)).unwrap();
|
||||
match SseEvent::try_from((sse_event, &client_api, &upstream_api)) {
|
||||
Ok(event) => event,
|
||||
Err(e) => {
|
||||
warn!(error = ?e, "failed to transform SSE event, skipping");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Try to get provider response and extract content delta
|
||||
match transformed_event.provider_response() {
|
||||
|
|
|
|||
|
|
@ -145,6 +145,7 @@ async fn init_app_state(
|
|||
listeners: Arc::new(RwLock::new(config.listeners.clone())),
|
||||
state_storage,
|
||||
llm_provider_url,
|
||||
http_client: reqwest::Client::new(),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -206,31 +207,18 @@ async fn route(
|
|||
stripped,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
return agent_chat(
|
||||
req,
|
||||
Arc::clone(&state.orchestrator_service),
|
||||
Arc::clone(&state.agents_list),
|
||||
Arc::clone(&state.listeners),
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await;
|
||||
return agent_chat(req, Arc::clone(&state))
|
||||
.with_context(parent_cx)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Standard routes ---
|
||||
match (req.method(), path.as_str()) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||
let url = format!("{}{}", state.llm_provider_url, path);
|
||||
llm_chat(
|
||||
req,
|
||||
Arc::clone(&state.router_service),
|
||||
url,
|
||||
Arc::clone(&state.model_aliases),
|
||||
Arc::clone(&state.llm_providers),
|
||||
state.state_storage.clone(),
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
llm_chat(req, Arc::clone(&state))
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::POST, "/function_calling") => {
|
||||
let url = format!("{}/v1/chat/completions", state.llm_provider_url);
|
||||
|
|
|
|||
|
|
@ -99,7 +99,8 @@ impl RouterService {
|
|||
"sending request to arch-router"
|
||||
);
|
||||
|
||||
let body = serde_json::to_string(&router_request).unwrap();
|
||||
let body = serde_json::to_string(&router_request)
|
||||
.map_err(super::router_model::RoutingModelError::from)?;
|
||||
debug!(body = %body, "arch router request");
|
||||
|
||||
let mut headers = header::HeaderMap::new();
|
||||
|
|
|
|||
|
|
@ -74,7 +74,8 @@ impl OrchestratorService {
|
|||
"sending request to arch-orchestrator"
|
||||
);
|
||||
|
||||
let body = serde_json::to_string(&orchestrator_request).unwrap();
|
||||
let body = serde_json::to_string(&orchestrator_request)
|
||||
.map_err(super::orchestrator_model::OrchestratorModelError::from)?;
|
||||
debug!(body = %body, "arch orchestrator request");
|
||||
|
||||
let mut headers = header::HeaderMap::new();
|
||||
|
|
|
|||
|
|
@ -88,12 +88,11 @@ pub trait StateStorage: Send + Sync {
|
|||
combined_input.extend(current_input);
|
||||
|
||||
debug!(
|
||||
"PLANO | BRIGHTSTAFF | STATE_STORAGE | RESP_ID:{} | Merged state: prev_items={}, current_items={}, total_items={}, combined_json={}",
|
||||
prev_state.response_id,
|
||||
prev_count,
|
||||
current_count,
|
||||
combined_input.len(),
|
||||
serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
|
||||
response_id = %prev_state.response_id,
|
||||
prev_items = prev_count,
|
||||
current_items = current_count,
|
||||
total_items = combined_input.len(),
|
||||
"merged conversation state"
|
||||
);
|
||||
|
||||
combined_input
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue