mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
refactor brightstaff (#736)
This commit is contained in:
parent
1f23c573bf
commit
1ad3e0f64e
30 changed files with 1802 additions and 1700 deletions
29
crates/brightstaff/src/app_state.rs
Normal file
29
crates/brightstaff/src/app_state.rs
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::configuration::{Agent, FilterPipeline, Listener, ModelAlias, SpanAttributes};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::router::llm::RouterService;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::state::StateStorage;
|
||||
|
||||
/// Shared application state bundled into a single Arc-wrapped struct.
|
||||
///
|
||||
/// Instead of cloning 8+ individual `Arc`s per connection, a single
|
||||
/// `Arc<AppState>` is cloned once and passed to the request handler.
|
||||
pub struct AppState {
|
||||
pub router_service: Arc<RouterService>,
|
||||
pub orchestrator_service: Arc<OrchestratorService>,
|
||||
pub model_aliases: Option<HashMap<String, ModelAlias>>,
|
||||
pub llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
pub agents_list: Option<Vec<Agent>>,
|
||||
pub listeners: Vec<Listener>,
|
||||
pub state_storage: Option<Arc<dyn StateStorage>>,
|
||||
pub llm_provider_url: String,
|
||||
pub span_attributes: Option<SpanAttributes>,
|
||||
/// Shared HTTP client for upstream LLM requests (connection pooling / keep-alive).
|
||||
pub http_client: reqwest::Client,
|
||||
pub filter_pipeline: Arc<FilterPipeline>,
|
||||
}
|
||||
41
crates/brightstaff/src/handlers/agents/errors.rs
Normal file
41
crates/brightstaff/src/handlers/agents/errors.rs
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
use bytes::Bytes;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use hyper::Response;
|
||||
use serde_json::json;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::handlers::response::ResponseHandler;
|
||||
|
||||
/// Build a JSON error response from an `AgentFilterChainError`, logging the
|
||||
/// full error chain along the way.
|
||||
///
|
||||
/// Returns `Ok(Response)` so it can be used directly as a handler return value.
|
||||
pub fn build_error_chain_response<E: std::error::Error>(
|
||||
err: &E,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let mut error_chain = Vec::new();
|
||||
let mut current: &dyn std::error::Error = err;
|
||||
loop {
|
||||
error_chain.push(current.to_string());
|
||||
match current.source() {
|
||||
Some(source) => current = source,
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
warn!(error_chain = ?error_chain, "agent chat error chain");
|
||||
warn!(root_error = ?err, "root error");
|
||||
|
||||
let error_json = json!({
|
||||
"error": {
|
||||
"type": "AgentFilterChainError",
|
||||
"message": err.to_string(),
|
||||
"error_chain": error_chain,
|
||||
"debug_info": format!("{:?}", err)
|
||||
}
|
||||
});
|
||||
|
||||
info!(error = %error_json, "structured error info");
|
||||
|
||||
Ok(ResponseHandler::create_json_error_response(&error_json))
|
||||
}
|
||||
5
crates/brightstaff/src/handlers/agents/mod.rs
Normal file
5
crates/brightstaff/src/handlers/agents/mod.rs
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
pub mod errors;
|
||||
pub mod jsonrpc;
|
||||
pub mod orchestrator;
|
||||
pub mod pipeline;
|
||||
pub mod selector;
|
||||
|
|
@ -2,63 +2,56 @@ use std::sync::Arc;
|
|||
use std::time::Instant;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::configuration::SpanAttributes;
|
||||
use common::errors::BrightStaffError;
|
||||
use common::llm_providers::LlmProviders;
|
||||
use hermesllm::apis::OpenAIMessage;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::providers::request::ProviderRequest;
|
||||
use hermesllm::ProviderRequestType;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use hyper::{Request, Response};
|
||||
use opentelemetry::trace::get_active_span;
|
||||
use serde::ser::Error as SerError;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use super::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
use super::response_handler::ResponseHandler;
|
||||
use crate::router::plano_orchestrator::OrchestratorService;
|
||||
use super::errors::build_error_chain_response;
|
||||
use super::pipeline::{PipelineError, PipelineProcessor};
|
||||
use super::selector::{AgentSelectionError, AgentSelector};
|
||||
use crate::app_state::AppState;
|
||||
use crate::handlers::extract_request_id;
|
||||
use crate::handlers::response::ResponseHandler;
|
||||
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
|
||||
|
||||
/// Main errors for agent chat completions
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AgentFilterChainError {
|
||||
#[error("Forwarded error: {0}")]
|
||||
Brightstaff(#[from] BrightStaffError),
|
||||
#[error("Agent selection error: {0}")]
|
||||
Selection(#[from] AgentSelectionError),
|
||||
#[error("Pipeline processing error: {0}")]
|
||||
Pipeline(#[from] PipelineError),
|
||||
#[error("Response handling error: {0}")]
|
||||
Response(#[from] common::errors::BrightStaffError),
|
||||
#[error("Request parsing error: {0}")]
|
||||
RequestParsing(#[from] serde_json::Error),
|
||||
RequestParsing(String),
|
||||
#[error("HTTP error: {0}")]
|
||||
Http(#[from] hyper::Error),
|
||||
#[error("Unsupported endpoint: {0}")]
|
||||
UnsupportedEndpoint(String),
|
||||
#[error("No agents configured")]
|
||||
NoAgentsConfigured,
|
||||
#[error("Agent '{0}' not found in configuration")]
|
||||
AgentNotFound(String),
|
||||
#[error("No messages in conversation history")]
|
||||
EmptyHistory,
|
||||
#[error("Agent chain completed without producing a response")]
|
||||
IncompleteChain,
|
||||
}
|
||||
|
||||
pub async fn agent_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
_: String,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
span_attributes: Arc<Option<SpanAttributes>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
state: Arc<AppState>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_id = extract_request_id(&request);
|
||||
let custom_attrs =
|
||||
collect_custom_trace_attributes(request.headers(), span_attributes.as_ref().as_ref());
|
||||
// Extract request_id from headers or generate a new one
|
||||
let request_id: String = match request
|
||||
.headers()
|
||||
.get(common::consts::REQUEST_ID_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
{
|
||||
Some(id) => id,
|
||||
None => uuid::Uuid::new_v4().to_string(),
|
||||
};
|
||||
collect_custom_trace_attributes(request.headers(), state.span_attributes.as_ref());
|
||||
|
||||
// Create a span with request_id that will be included in all log lines
|
||||
let request_span = info_span!(
|
||||
|
|
@ -74,17 +67,7 @@ pub async fn agent_chat(
|
|||
// Set service name for orchestrator operations
|
||||
set_service_name(operation_component::ORCHESTRATOR);
|
||||
|
||||
match handle_agent_chat_inner(
|
||||
request,
|
||||
orchestrator_service,
|
||||
agents_list,
|
||||
listeners,
|
||||
llm_providers,
|
||||
request_id,
|
||||
custom_attrs,
|
||||
)
|
||||
.await
|
||||
{
|
||||
match handle_agent_chat_inner(request, state, request_id, custom_attrs).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
// Check if this is a client error from the pipeline that should be cascaded
|
||||
|
|
@ -101,7 +84,6 @@ pub async fn agent_chat(
|
|||
"client error from agent"
|
||||
);
|
||||
|
||||
// Create error response with the original status code and body
|
||||
let error_json = serde_json::json!({
|
||||
"error": "ClientError",
|
||||
"agent": agent,
|
||||
|
|
@ -109,52 +91,19 @@ pub async fn agent_chat(
|
|||
"agent_response": body
|
||||
});
|
||||
|
||||
let status_code = hyper::StatusCode::from_u16(*status)
|
||||
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
let json_string = error_json.to_string();
|
||||
return Ok(BrightStaffError::ForwardedError {
|
||||
status_code,
|
||||
message: json_string,
|
||||
}
|
||||
.into_response());
|
||||
let mut response =
|
||||
Response::new(ResponseHandler::create_full_body(json_string));
|
||||
*response.status_mut() = hyper::StatusCode::from_u16(*status)
|
||||
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
response.headers_mut().insert(
|
||||
hyper::header::CONTENT_TYPE,
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
// Print detailed error information with full error chain for other errors
|
||||
let mut error_chain = Vec::new();
|
||||
let mut current_error: &dyn std::error::Error = &err;
|
||||
|
||||
// Collect the full error chain
|
||||
loop {
|
||||
error_chain.push(current_error.to_string());
|
||||
match current_error.source() {
|
||||
Some(source) => current_error = source,
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
// Log the complete error chain
|
||||
warn!(error_chain = ?error_chain, "agent chat error chain");
|
||||
warn!(root_error = ?err, "root error");
|
||||
|
||||
// Create structured error response as JSON
|
||||
let error_json = serde_json::json!({
|
||||
"error": {
|
||||
"type": "AgentFilterChainError",
|
||||
"message": err.to_string(),
|
||||
"error_chain": error_chain,
|
||||
"debug_info": format!("{:?}", err)
|
||||
}
|
||||
});
|
||||
|
||||
// Log the error for debugging
|
||||
info!(error = %error_json, "structured error info");
|
||||
|
||||
Ok(BrightStaffError::ForwardedError {
|
||||
status_code: StatusCode::BAD_REQUEST,
|
||||
message: error_json.to_string(),
|
||||
}
|
||||
.into_response())
|
||||
build_error_chain_response(&err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -162,19 +111,22 @@ pub async fn agent_chat(
|
|||
.await
|
||||
}
|
||||
|
||||
async fn handle_agent_chat_inner(
|
||||
/// Parsed and validated agent request data.
|
||||
struct AgentRequest {
|
||||
client_request: ProviderRequestType,
|
||||
messages: Vec<OpenAIMessage>,
|
||||
request_headers: hyper::HeaderMap,
|
||||
request_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Parse the incoming HTTP request, resolve the listener, and extract messages.
|
||||
async fn parse_agent_request(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
request_id: String,
|
||||
custom_attrs: std::collections::HashMap<String, String>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||
// Initialize services
|
||||
let agent_selector = AgentSelector::new(orchestrator_service);
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let response_handler = ResponseHandler::new();
|
||||
state: &AppState,
|
||||
request_id: &str,
|
||||
custom_attrs: &std::collections::HashMap<String, String>,
|
||||
) -> Result<(AgentRequest, common::configuration::Listener, AgentSelector), AgentFilterChainError> {
|
||||
let agent_selector = AgentSelector::new(Arc::clone(&state.orchestrator_service));
|
||||
|
||||
// Extract listener name from headers
|
||||
let listener_name = request
|
||||
|
|
@ -183,16 +135,11 @@ async fn handle_agent_chat_inner(
|
|||
.and_then(|name| name.to_str().ok());
|
||||
|
||||
// Find the appropriate listener
|
||||
let listener: common::configuration::Listener = {
|
||||
let listeners = listeners.read().await;
|
||||
agent_selector
|
||||
.find_listener(listener_name, &listeners)
|
||||
.await?
|
||||
};
|
||||
let listener = agent_selector.find_listener(listener_name, &state.listeners)?;
|
||||
|
||||
get_active_span(|span| {
|
||||
span.update_name(listener.name.to_string());
|
||||
for (key, value) in &custom_attrs {
|
||||
for (key, value) in custom_attrs {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
|
||||
}
|
||||
});
|
||||
|
|
@ -200,24 +147,20 @@ async fn handle_agent_chat_inner(
|
|||
info!(listener = %listener.name, "handling request");
|
||||
|
||||
// Parse request body
|
||||
let request_path = request
|
||||
.uri()
|
||||
.path()
|
||||
.to_string()
|
||||
let full_path = request.uri().path().to_string();
|
||||
let request_path = full_path
|
||||
.strip_prefix("/agents")
|
||||
.unwrap()
|
||||
.unwrap_or(&full_path)
|
||||
.to_string();
|
||||
|
||||
let request_headers = {
|
||||
let mut headers = request.headers().clone();
|
||||
headers.remove(common::consts::ENVOY_ORIGINAL_PATH_HEADER);
|
||||
|
||||
// Set the request_id in headers if not already present
|
||||
if !headers.contains_key(common::consts::REQUEST_ID_HEADER) {
|
||||
headers.insert(
|
||||
common::consts::REQUEST_ID_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&request_id).unwrap(),
|
||||
);
|
||||
if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) {
|
||||
headers.insert(common::consts::REQUEST_ID_HEADER, val);
|
||||
}
|
||||
}
|
||||
|
||||
headers
|
||||
|
|
@ -230,63 +173,62 @@ async fn handle_agent_chat_inner(
|
|||
"received request body"
|
||||
);
|
||||
|
||||
// Determine the API type from the endpoint
|
||||
let api_type =
|
||||
SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| {
|
||||
let err_msg = format!("Unsupported endpoint: {}", request_path);
|
||||
warn!("{}", err_msg);
|
||||
AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
|
||||
warn!(path = %request_path, "unsupported endpoint");
|
||||
AgentFilterChainError::UnsupportedEndpoint(request_path.clone())
|
||||
})?;
|
||||
|
||||
let mut client_request =
|
||||
match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("failed to parse request as ProviderRequestType: {}", err);
|
||||
let err_msg = format!("Failed to parse request: {}", err);
|
||||
return Err(AgentFilterChainError::RequestParsing(
|
||||
serde_json::Error::custom(err_msg),
|
||||
));
|
||||
}
|
||||
};
|
||||
let client_request = ProviderRequestType::try_from((&chat_request_bytes[..], &api_type))
|
||||
.map_err(|err| {
|
||||
warn!(error = %err, "failed to parse request as ProviderRequestType");
|
||||
AgentFilterChainError::RequestParsing(format!("Failed to parse request: {}", err))
|
||||
})?;
|
||||
|
||||
// If model is not specified in the request, resolve from default provider
|
||||
if client_request.model().is_empty() {
|
||||
match llm_providers.read().await.default() {
|
||||
Some(default_provider) => {
|
||||
let default_model = default_provider.name.clone();
|
||||
info!(default_model = %default_model, "no model specified in request, using default provider");
|
||||
client_request.set_model(default_model);
|
||||
}
|
||||
None => {
|
||||
let err_msg = "No model specified in request and no default provider configured";
|
||||
warn!("{}", err_msg);
|
||||
return Ok(BrightStaffError::NoModelSpecified.into_response());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let message: Vec<OpenAIMessage> = client_request.get_messages();
|
||||
let messages: Vec<OpenAIMessage> = client_request.get_messages();
|
||||
|
||||
let request_id = request_headers
|
||||
.get(common::consts::REQUEST_ID_HEADER)
|
||||
.and_then(|val| val.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// Create agent map for pipeline processing and agent selection
|
||||
let agent_map = {
|
||||
let agents = agents_list.read().await;
|
||||
let agents = agents.as_ref().unwrap();
|
||||
agent_selector.create_agent_map(agents)
|
||||
};
|
||||
Ok((
|
||||
AgentRequest {
|
||||
client_request,
|
||||
messages,
|
||||
request_headers,
|
||||
request_id,
|
||||
},
|
||||
listener,
|
||||
agent_selector,
|
||||
))
|
||||
}
|
||||
|
||||
/// Select agents via the orchestrator model and record selection metrics.
|
||||
async fn select_and_build_agent_map(
|
||||
agent_selector: &AgentSelector,
|
||||
state: &AppState,
|
||||
messages: &[OpenAIMessage],
|
||||
listener: &common::configuration::Listener,
|
||||
request_id: Option<String>,
|
||||
) -> Result<
|
||||
(
|
||||
Vec<common::configuration::AgentFilterChain>,
|
||||
std::collections::HashMap<String, common::configuration::Agent>,
|
||||
),
|
||||
AgentFilterChainError,
|
||||
> {
|
||||
let agents = state
|
||||
.agents_list
|
||||
.as_ref()
|
||||
.ok_or(AgentFilterChainError::NoAgentsConfigured)?;
|
||||
let agent_map = agent_selector.create_agent_map(agents);
|
||||
|
||||
// Select appropriate agents using arch orchestrator llm model
|
||||
let selection_start = Instant::now();
|
||||
let selected_agents = agent_selector
|
||||
.select_agents(&message, &listener, request_id.clone())
|
||||
.select_agents(messages, listener, request_id)
|
||||
.await?;
|
||||
|
||||
// Record selection attributes on the current orchestrator span
|
||||
let selection_elapsed_ms = selection_start.elapsed().as_secs_f64() * 1000.0;
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
|
|
@ -316,12 +258,25 @@ async fn handle_agent_chat_inner(
|
|||
"selected agents for execution"
|
||||
);
|
||||
|
||||
// Execute agents sequentially, passing output from one to the next
|
||||
let mut current_messages = message.clone();
|
||||
Ok((selected_agents, agent_map))
|
||||
}
|
||||
|
||||
/// Execute the agent chain: run each selected agent sequentially, streaming
|
||||
/// the final agent's response back to the client.
|
||||
async fn execute_agent_chain(
|
||||
selected_agents: &[common::configuration::AgentFilterChain],
|
||||
agent_map: &std::collections::HashMap<String, common::configuration::Agent>,
|
||||
client_request: ProviderRequestType,
|
||||
messages: Vec<OpenAIMessage>,
|
||||
request_headers: &hyper::HeaderMap,
|
||||
custom_attrs: &std::collections::HashMap<String, String>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let response_handler = ResponseHandler::new();
|
||||
let mut current_messages = messages;
|
||||
let agent_count = selected_agents.len();
|
||||
|
||||
for (agent_index, selected_agent) in selected_agents.iter().enumerate() {
|
||||
// Get agent name
|
||||
let agent_name = selected_agent.id.clone();
|
||||
let is_last_agent = agent_index == agent_count - 1;
|
||||
|
||||
|
|
@ -332,8 +287,6 @@ async fn handle_agent_chat_inner(
|
|||
"processing agent"
|
||||
);
|
||||
|
||||
// Process input filters — serialize current request as OpenAI chat completions body,
|
||||
// pass raw bytes through each filter, then extract updated messages from the result.
|
||||
let chat_history = if selected_agent
|
||||
.input_filters
|
||||
.as_ref()
|
||||
|
|
@ -351,8 +304,8 @@ async fn handle_agent_chat_inner(
|
|||
.process_raw_filter_chain(
|
||||
&filter_bytes,
|
||||
selected_agent,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
agent_map,
|
||||
request_headers,
|
||||
"/v1/chat/completions",
|
||||
)
|
||||
.await?;
|
||||
|
|
@ -365,8 +318,9 @@ async fn handle_agent_chat_inner(
|
|||
current_messages.clone()
|
||||
};
|
||||
|
||||
// Get agent details and invoke
|
||||
let agent = agent_map.get(&agent_name).unwrap();
|
||||
let agent = agent_map
|
||||
.get(&agent_name)
|
||||
.ok_or_else(|| AgentFilterChainError::AgentNotFound(agent_name.clone()))?;
|
||||
|
||||
debug!(agent = %agent_name, "invoking agent");
|
||||
|
||||
|
|
@ -380,7 +334,7 @@ async fn handle_agent_chat_inner(
|
|||
set_service_name(operation_component::AGENT);
|
||||
get_active_span(|span| {
|
||||
span.update_name(format!("{} /v1/chat/completions", agent_name));
|
||||
for (key, value) in &custom_attrs {
|
||||
for (key, value) in custom_attrs {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
|
||||
}
|
||||
});
|
||||
|
|
@ -390,28 +344,25 @@ async fn handle_agent_chat_inner(
|
|||
&chat_history,
|
||||
client_request.clone(),
|
||||
agent,
|
||||
&request_headers,
|
||||
request_headers,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.instrument(agent_span.clone())
|
||||
.await?;
|
||||
|
||||
// If this is the last agent, return the streaming response
|
||||
if is_last_agent {
|
||||
info!(
|
||||
agent = %agent_name,
|
||||
"completed agent chain, returning response"
|
||||
);
|
||||
// Capture the orchestrator span (parent of the agent span) so it
|
||||
// stays open for the full streaming duration alongside the agent span.
|
||||
let orchestrator_span = tracing::Span::current();
|
||||
return async {
|
||||
response_handler
|
||||
.create_streaming_response(
|
||||
llm_response,
|
||||
tracing::Span::current(), // agent span (inner)
|
||||
orchestrator_span, // orchestrator span (outer)
|
||||
tracing::Span::current(),
|
||||
orchestrator_span,
|
||||
)
|
||||
.await
|
||||
.map_err(AgentFilterChainError::from)
|
||||
|
|
@ -420,7 +371,6 @@ async fn handle_agent_chat_inner(
|
|||
.await;
|
||||
}
|
||||
|
||||
// For intermediate agents, collect the full response and pass to next agent
|
||||
debug!(agent = %agent_name, "collecting response from intermediate agent");
|
||||
let response_text = async { response_handler.collect_full_response(llm_response).await }
|
||||
.instrument(agent_span)
|
||||
|
|
@ -432,11 +382,11 @@ async fn handle_agent_chat_inner(
|
|||
"agent completed, passing response to next agent"
|
||||
);
|
||||
|
||||
// remove last message and add new one at the end
|
||||
let last_message = current_messages.pop().unwrap();
|
||||
let Some(last_message) = current_messages.pop() else {
|
||||
warn!(agent = %agent_name, "no messages in conversation history");
|
||||
return Err(AgentFilterChainError::EmptyHistory);
|
||||
};
|
||||
|
||||
// Create a new message with the agent's response as assistant message
|
||||
// and add it to the conversation history
|
||||
current_messages.push(OpenAIMessage {
|
||||
role: hermesllm::apis::openai::Role::Assistant,
|
||||
content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)),
|
||||
|
|
@ -448,6 +398,34 @@ async fn handle_agent_chat_inner(
|
|||
current_messages.push(last_message);
|
||||
}
|
||||
|
||||
// This should never be reached since we return in the last agent iteration
|
||||
unreachable!("Agent execution loop should have returned a response")
|
||||
Err(AgentFilterChainError::IncompleteChain)
|
||||
}
|
||||
|
||||
async fn handle_agent_chat_inner(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
state: Arc<AppState>,
|
||||
request_id: String,
|
||||
custom_attrs: std::collections::HashMap<String, String>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||
let (agent_req, listener, agent_selector) =
|
||||
parse_agent_request(request, &state, &request_id, &custom_attrs).await?;
|
||||
|
||||
let (selected_agents, agent_map) = select_and_build_agent_map(
|
||||
&agent_selector,
|
||||
&state,
|
||||
&agent_req.messages,
|
||||
&listener,
|
||||
agent_req.request_id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_agent_chain(
|
||||
&selected_agents,
|
||||
&agent_map,
|
||||
agent_req.client_request,
|
||||
agent_req.messages,
|
||||
&agent_req.request_headers,
|
||||
&custom_attrs,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
@ -12,7 +12,7 @@ use opentelemetry::global;
|
|||
use opentelemetry_http::HeaderInjector;
|
||||
use tracing::{debug, info, instrument, warn};
|
||||
|
||||
use crate::handlers::jsonrpc::{
|
||||
use super::jsonrpc::{
|
||||
JsonRpcId, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, JSON_RPC_VERSION,
|
||||
MCP_INITIALIZE, MCP_INITIALIZE_NOTIFICATION, TOOL_CALL_METHOD,
|
||||
};
|
||||
|
|
@ -78,12 +78,11 @@ impl PipelineProcessor {
|
|||
}
|
||||
}
|
||||
|
||||
/// Build common MCP headers for requests
|
||||
fn build_mcp_headers(
|
||||
&self,
|
||||
/// Prepare headers shared by all agent/filter requests: removes
|
||||
/// content-length, injects trace context, sets upstream host and retry.
|
||||
fn build_agent_headers(
|
||||
request_headers: &HeaderMap,
|
||||
agent_id: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<HeaderMap, PipelineError> {
|
||||
let mut headers = request_headers.clone();
|
||||
headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
|
@ -104,24 +103,34 @@ impl PipelineProcessor {
|
|||
|
||||
headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
hyper::header::HeaderValue::from_static("3"),
|
||||
);
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
/// Build headers for MCP requests (adds Accept, Content-Type, optional session id).
|
||||
fn build_mcp_headers(
|
||||
&self,
|
||||
request_headers: &HeaderMap,
|
||||
agent_id: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<HeaderMap, PipelineError> {
|
||||
let mut headers = Self::build_agent_headers(request_headers, agent_id)?;
|
||||
|
||||
headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json, text/event-stream"),
|
||||
);
|
||||
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
if let Some(sid) = session_id {
|
||||
headers.insert(
|
||||
"mcp-session-id",
|
||||
hyper::header::HeaderValue::from_str(sid).unwrap(),
|
||||
);
|
||||
if let Ok(val) = hyper::header::HeaderValue::from_str(sid) {
|
||||
headers.insert("mcp-session-id", val);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(headers)
|
||||
|
|
@ -243,7 +252,7 @@ impl PipelineProcessor {
|
|||
let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) {
|
||||
session_id.clone()
|
||||
} else {
|
||||
let session_id = self.get_new_session_id(&agent.id, request_headers).await;
|
||||
let session_id = self.get_new_session_id(&agent.id, request_headers).await?;
|
||||
self.agent_id_session_map
|
||||
.insert(agent.id.clone(), session_id.clone());
|
||||
session_id
|
||||
|
|
@ -394,18 +403,19 @@ impl PipelineProcessor {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_new_session_id(&self, agent_id: &str, request_headers: &HeaderMap) -> String {
|
||||
async fn get_new_session_id(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<String, PipelineError> {
|
||||
info!("initializing MCP session for agent {}", agent_id);
|
||||
|
||||
let initialize_request = self.build_initialize_request();
|
||||
let headers = self
|
||||
.build_mcp_headers(request_headers, agent_id, None)
|
||||
.expect("Failed to build headers for initialization");
|
||||
let headers = self.build_mcp_headers(request_headers, agent_id, None)?;
|
||||
|
||||
let response = self
|
||||
.send_mcp_request(&initialize_request, &headers, agent_id)
|
||||
.await
|
||||
.expect("Failed to initialize MCP session");
|
||||
.await?;
|
||||
|
||||
info!("initialize response status: {}", response.status());
|
||||
|
||||
|
|
@ -413,8 +423,13 @@ impl PipelineProcessor {
|
|||
.headers()
|
||||
.get("mcp-session-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.expect("No mcp-session-id in response")
|
||||
.to_string();
|
||||
.map(|s| s.to_string())
|
||||
.ok_or_else(|| {
|
||||
PipelineError::NoContentInResponse(format!(
|
||||
"No mcp-session-id header in initialize response from agent {}",
|
||||
agent_id
|
||||
))
|
||||
})?;
|
||||
|
||||
info!(
|
||||
"created new MCP session for agent {}: {}",
|
||||
|
|
@ -423,10 +438,9 @@ impl PipelineProcessor {
|
|||
|
||||
// Send initialized notification
|
||||
self.send_initialized_notification(agent_id, &session_id, &headers)
|
||||
.await
|
||||
.expect("Failed to send initialized notification");
|
||||
.await?;
|
||||
|
||||
session_id
|
||||
Ok(session_id)
|
||||
}
|
||||
|
||||
/// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back.
|
||||
|
|
@ -454,25 +468,7 @@ impl PipelineProcessor {
|
|||
span.update_name(format!("execute_raw_filter ({})", agent.id));
|
||||
});
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
|
||||
});
|
||||
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
||||
);
|
||||
agent_headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
let mut agent_headers = Self::build_agent_headers(request_headers, &agent.id)?;
|
||||
agent_headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
|
|
@ -578,36 +574,15 @@ impl PipelineProcessor {
|
|||
terminal_agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<reqwest::Response, PipelineError> {
|
||||
// let mut request = original_request.clone();
|
||||
original_request.set_messages(messages);
|
||||
|
||||
let request_url = "/v1/chat/completions";
|
||||
|
||||
let request_body = ProviderRequestType::to_bytes(&original_request).unwrap();
|
||||
// let request_body = serde_json::to_string(&request)?;
|
||||
let request_body = ProviderRequestType::to_bytes(&original_request)
|
||||
.map_err(|e| PipelineError::NoContentInResponse(e.to_string()))?;
|
||||
debug!("sending request to terminal agent {}", terminal_agent.id);
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
// Inject OpenTelemetry trace context automatically
|
||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
|
||||
});
|
||||
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&terminal_agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(terminal_agent.id.clone()))?,
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
let agent_headers = Self::build_agent_headers(request_headers, &terminal_agent.id)?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
|
|
@ -7,7 +7,7 @@ use common::configuration::{
|
|||
use hermesllm::apis::openai::Message;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::router::plano_orchestrator::OrchestratorService;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
|
||||
/// Errors that can occur during agent selection
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -37,7 +37,7 @@ impl AgentSelector {
|
|||
}
|
||||
|
||||
/// Find listener by name from the request headers
|
||||
pub async fn find_listener(
|
||||
pub fn find_listener(
|
||||
&self,
|
||||
listener_name: Option<&str>,
|
||||
listeners: &[common::configuration::Listener],
|
||||
|
|
@ -84,7 +84,7 @@ impl AgentSelector {
|
|||
}
|
||||
|
||||
/// Convert agent descriptions to orchestration preferences
|
||||
async fn convert_agent_description_to_orchestration_preferences(
|
||||
fn convert_agent_description_to_orchestration_preferences(
|
||||
&self,
|
||||
agents: &[AgentFilterChain],
|
||||
) -> Vec<AgentUsagePreference> {
|
||||
|
|
@ -121,9 +121,7 @@ impl AgentSelector {
|
|||
return Ok(vec![agents[0].clone()]);
|
||||
}
|
||||
|
||||
let usage_preferences = self
|
||||
.convert_agent_description_to_orchestration_preferences(agents)
|
||||
.await;
|
||||
let usage_preferences = self.convert_agent_description_to_orchestration_preferences(agents);
|
||||
debug!(
|
||||
"Agents usage preferences for orchestration: {}",
|
||||
serde_json::to_string(&usage_preferences).unwrap_or_default()
|
||||
|
|
@ -222,9 +220,7 @@ mod tests {
|
|||
let listener2 = create_test_listener("other-listener", vec![]);
|
||||
let listeners = vec![listener1.clone(), listener2];
|
||||
|
||||
let result = selector
|
||||
.find_listener(Some("test-listener"), &listeners)
|
||||
.await;
|
||||
let result = selector.find_listener(Some("test-listener"), &listeners);
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().name, "test-listener");
|
||||
|
|
@ -237,9 +233,7 @@ mod tests {
|
|||
|
||||
let listeners = vec![create_test_listener("other-listener", vec![])];
|
||||
|
||||
let result = selector
|
||||
.find_listener(Some("nonexistent"), &listeners)
|
||||
.await;
|
||||
let result = selector.find_listener(Some("nonexistent"), &listeners);
|
||||
|
||||
assert!(result.is_err());
|
||||
matches!(
|
||||
|
|
@ -3,12 +3,11 @@ use std::sync::Arc;
|
|||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use hyper::header::HeaderMap;
|
||||
|
||||
use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
use crate::handlers::pipeline_processor::PipelineProcessor;
|
||||
use crate::router::plano_orchestrator::OrchestratorService;
|
||||
use common::errors::BrightStaffError;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::StatusCode;
|
||||
use crate::handlers::agents::pipeline::PipelineProcessor;
|
||||
use crate::handlers::agents::selector::{AgentSelectionError, AgentSelector};
|
||||
use crate::handlers::response::ResponseHandler;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
|
||||
/// Integration test that demonstrates the modular agent chat flow
|
||||
/// This test shows how the three main components work together:
|
||||
/// 1. AgentSelector - selects the appropriate agents based on orchestration
|
||||
|
|
@ -86,9 +85,7 @@ mod tests {
|
|||
let messages = vec![create_test_message(Role::User, "Hello world!")];
|
||||
|
||||
// Test 1: Agent Selection
|
||||
let selected_listener = agent_selector
|
||||
.find_listener(Some("test-listener"), &listeners)
|
||||
.await;
|
||||
let selected_listener = agent_selector.find_listener(Some("test-listener"), &listeners);
|
||||
|
||||
assert!(selected_listener.is_ok());
|
||||
let listener = selected_listener.unwrap();
|
||||
|
|
@ -142,24 +139,8 @@ mod tests {
|
|||
}
|
||||
|
||||
// Test 4: Error Response Creation
|
||||
let err = BrightStaffError::ModelNotFound("gpt-5-secret".to_string());
|
||||
let response = err.into_response();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
// Helper to extract body as JSON
|
||||
let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
|
||||
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
|
||||
assert_eq!(body["error"]["code"], "ModelNotFound");
|
||||
assert_eq!(
|
||||
body["error"]["details"]["rejected_model_id"],
|
||||
"gpt-5-secret"
|
||||
);
|
||||
assert!(body["error"]["message"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.contains("gpt-5-secret"));
|
||||
let error_response = ResponseHandler::create_bad_request("Test error");
|
||||
assert_eq!(error_response.status(), hyper::StatusCode::BAD_REQUEST);
|
||||
|
||||
println!("✅ All modular components working correctly!");
|
||||
}
|
||||
|
|
@ -170,7 +151,7 @@ mod tests {
|
|||
let agent_selector = AgentSelector::new(router_service);
|
||||
|
||||
// Test listener not found
|
||||
let result = agent_selector.find_listener(Some("nonexistent"), &[]).await;
|
||||
let result = agent_selector.find_listener(Some("nonexistent"), &[]);
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
|
|
@ -178,21 +159,12 @@ mod tests {
|
|||
AgentSelectionError::ListenerNotFound(_)
|
||||
));
|
||||
|
||||
let technical_reason = "Database connection timed out";
|
||||
let err = BrightStaffError::InternalServerError(technical_reason.to_string());
|
||||
|
||||
let response = err.into_response();
|
||||
|
||||
// --- 1. EXTRACT BYTES ---
|
||||
let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
|
||||
|
||||
// --- 2. DECLARE body_json HERE ---
|
||||
let body_json: serde_json::Value =
|
||||
serde_json::from_slice(&body_bytes).expect("Failed to parse JSON body");
|
||||
|
||||
// --- 3. USE body_json ---
|
||||
assert_eq!(body_json["error"]["code"], "InternalServerError");
|
||||
assert_eq!(body_json["error"]["details"]["reason"], technical_reason);
|
||||
// Test error response creation
|
||||
let error_response = ResponseHandler::create_internal_error("Pipeline failed");
|
||||
assert_eq!(
|
||||
error_response.status(),
|
||||
hyper::StatusCode::INTERNAL_SERVER_ERROR
|
||||
);
|
||||
|
||||
println!("✅ Error handling working correctly!");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,686 +0,0 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{FilterPipeline, ModelAlias, SpanAttributes};
|
||||
use common::consts::{
|
||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use hermesllm::apis::openai_responses::InputParam;
|
||||
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use opentelemetry::global;
|
||||
use opentelemetry::trace::get_active_span;
|
||||
use opentelemetry_http::HeaderInjector;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use super::pipeline_processor::PipelineProcessor;
|
||||
|
||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||
use crate::handlers::streaming::{
|
||||
create_streaming_response, create_streaming_response_with_output_filter, truncate_message,
|
||||
ObservableStreamProcessor, StreamProcessor,
|
||||
};
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::state::response_state_processor::ResponsesStateProcessor;
|
||||
use crate::state::{
|
||||
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
|
||||
};
|
||||
use crate::tracing::{
|
||||
collect_custom_trace_attributes, llm as tracing_llm, operation_component, set_service_name,
|
||||
};
|
||||
|
||||
use common::errors::BrightStaffError;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn llm_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
span_attributes: Arc<Option<SpanAttributes>>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
filter_pipeline: Arc<FilterPipeline>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_path = request.uri().path().to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
let request_id: String = match request_headers
|
||||
.get(REQUEST_ID_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
{
|
||||
Some(id) => id,
|
||||
None => uuid::Uuid::new_v4().to_string(),
|
||||
};
|
||||
let custom_attrs =
|
||||
collect_custom_trace_attributes(&request_headers, span_attributes.as_ref().as_ref());
|
||||
|
||||
// Create a span with request_id that will be included in all log lines
|
||||
let request_span = info_span!(
|
||||
"llm",
|
||||
component = "llm",
|
||||
request_id = %request_id,
|
||||
http.method = %request.method(),
|
||||
http.path = %request_path,
|
||||
llm.model = tracing::field::Empty,
|
||||
llm.tools = tracing::field::Empty,
|
||||
llm.user_message_preview = tracing::field::Empty,
|
||||
llm.temperature = tracing::field::Empty,
|
||||
);
|
||||
|
||||
// Execute the rest of the handler inside the span
|
||||
llm_chat_inner(
|
||||
request,
|
||||
router_service,
|
||||
full_qualified_llm_provider_url,
|
||||
model_aliases,
|
||||
llm_providers,
|
||||
custom_attrs,
|
||||
state_storage,
|
||||
request_id,
|
||||
request_path,
|
||||
request_headers,
|
||||
filter_pipeline,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn llm_chat_inner(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
custom_attrs: HashMap<String, String>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
request_id: String,
|
||||
request_path: String,
|
||||
mut request_headers: hyper::HeaderMap,
|
||||
filter_pipeline: Arc<FilterPipeline>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
// Set service name for LLM operations
|
||||
set_service_name(operation_component::LLM);
|
||||
get_active_span(|span| {
|
||||
for (key, value) in &custom_attrs {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
|
||||
}
|
||||
});
|
||||
|
||||
// Extract or generate traceparent - this establishes the trace context for all spans
|
||||
let traceparent: String = match request_headers
|
||||
.get(TRACE_PARENT_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
{
|
||||
Some(tp) => tp,
|
||||
None => {
|
||||
use uuid::Uuid;
|
||||
let trace_id = Uuid::new_v4().to_string().replace("-", "");
|
||||
let generated_tp = format!("00-{}-0000000000000000-01", trace_id);
|
||||
warn!(
|
||||
generated_traceparent = %generated_tp,
|
||||
"TRACE_PARENT header missing, generated new traceparent"
|
||||
);
|
||||
generated_tp
|
||||
}
|
||||
};
|
||||
|
||||
let raw_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!(
|
||||
body = %String::from_utf8_lossy(&raw_bytes),
|
||||
"request body received"
|
||||
);
|
||||
|
||||
// Extract routing_policy from request body if present
|
||||
let (chat_request_bytes, inline_routing_policy) =
|
||||
match crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false) {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Failed to parse request: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
let mut client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
error = %err,
|
||||
"failed to parse request as ProviderRequestType"
|
||||
);
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Failed to parse request: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
// === v1/responses state management: Extract input items early ===
|
||||
let mut original_input_items = Vec::new();
|
||||
let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
|
||||
let is_responses_api_client = matches!(
|
||||
client_api,
|
||||
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
|
||||
);
|
||||
|
||||
// If model is not specified in the request, resolve from default provider
|
||||
let model_from_request = client_request.model().to_string();
|
||||
let model_from_request = if model_from_request.is_empty() {
|
||||
match llm_providers.read().await.default() {
|
||||
Some(default_provider) => {
|
||||
let default_model = default_provider.name.clone();
|
||||
info!(default_model = %default_model, "no model specified in request, using default provider");
|
||||
client_request.set_model(default_model.clone());
|
||||
default_model
|
||||
}
|
||||
None => {
|
||||
let err_msg = "No model specified in request and no default provider configured";
|
||||
warn!("{}", err_msg);
|
||||
return Ok(BrightStaffError::NoModelSpecified.into_response());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
model_from_request
|
||||
};
|
||||
|
||||
// Model alias resolution: update model field in client_request immediately
|
||||
// This ensures all downstream objects use the resolved model
|
||||
let temperature = client_request.get_temperature();
|
||||
let is_streaming_request = client_request.is_streaming();
|
||||
let alias_resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
|
||||
let (provider_id, _) = get_provider_info(&llm_providers, &alias_resolved_model).await;
|
||||
|
||||
// Validate that the requested model exists in configuration
|
||||
// This matches the validation in llm_gateway routing.rs
|
||||
if llm_providers
|
||||
.read()
|
||||
.await
|
||||
.get(&alias_resolved_model)
|
||||
.is_none()
|
||||
{
|
||||
warn!(model = %alias_resolved_model, "model not found in configured providers");
|
||||
return Ok(BrightStaffError::ModelNotFound(alias_resolved_model).into_response());
|
||||
}
|
||||
|
||||
// Handle provider/model slug format (e.g., "openai/gpt-4")
|
||||
// Extract just the model name for upstream (providers don't understand the slug)
|
||||
let model_name_only = if let Some((_, model)) = alias_resolved_model.split_once('/') {
|
||||
model.to_string()
|
||||
} else {
|
||||
alias_resolved_model.clone()
|
||||
};
|
||||
|
||||
// Extract tool names and user message preview for span attributes
|
||||
let tool_names = client_request.get_tool_names();
|
||||
let user_message_preview = client_request
|
||||
.get_recent_user_message()
|
||||
.map(|msg| truncate_message(&msg, 50));
|
||||
let span = tracing::Span::current();
|
||||
if let Some(temp) = temperature {
|
||||
span.record(tracing_llm::TEMPERATURE, tracing::field::display(temp));
|
||||
}
|
||||
if let Some(tools) = &tool_names {
|
||||
let formatted_tools = tools
|
||||
.iter()
|
||||
.map(|name| format!("{}(...)", name))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
span.record(tracing_llm::TOOLS, formatted_tools.as_str());
|
||||
}
|
||||
if let Some(preview) = &user_message_preview {
|
||||
span.record(tracing_llm::USER_MESSAGE_PREVIEW, preview.as_str());
|
||||
}
|
||||
|
||||
// Extract messages for signal analysis (clone before moving client_request)
|
||||
let messages_for_signals = Some(client_request.get_messages());
|
||||
|
||||
// Set the model to just the model name (without provider prefix)
|
||||
// This ensures upstream receives "gpt-4" not "openai/gpt-4"
|
||||
client_request.set_model(model_name_only.clone());
|
||||
if client_request.remove_metadata_key("plano_preference_config") {
|
||||
debug!("removed plano_preference_config from metadata");
|
||||
}
|
||||
|
||||
// === Input filters processing for model listener ===
|
||||
// Filters receive the raw request bytes and return (possibly modified) raw bytes.
|
||||
// The returned bytes are re-parsed into a ProviderRequestType to continue the request.
|
||||
{
|
||||
if let Some(ref input_chain) = filter_pipeline.input {
|
||||
if !input_chain.is_empty() {
|
||||
debug!(input_filters = ?input_chain.filter_ids, "processing model listener input filters");
|
||||
|
||||
let chain = input_chain.to_agent_filter_chain("model_listener");
|
||||
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
match pipeline_processor
|
||||
.process_raw_filter_chain(
|
||||
&chat_request_bytes,
|
||||
&chain,
|
||||
&input_chain.agents,
|
||||
&request_headers,
|
||||
&request_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_bytes) => {
|
||||
match ProviderRequestType::try_from((
|
||||
&filtered_bytes[..],
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(updated_request) => {
|
||||
client_request = updated_request;
|
||||
info!("input filter chain processed successfully");
|
||||
}
|
||||
Err(parse_err) => {
|
||||
warn!(error = %parse_err, "input filter returned invalid request JSON");
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Input filter returned invalid request: {}",
|
||||
parse_err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(super::pipeline_processor::PipelineError::ClientError {
|
||||
agent,
|
||||
status,
|
||||
body,
|
||||
}) => {
|
||||
warn!(
|
||||
agent = %agent,
|
||||
status = %status,
|
||||
body = %body,
|
||||
"client error from filter chain"
|
||||
);
|
||||
let error_json = serde_json::json!({
|
||||
"error": "FilterChainError",
|
||||
"agent": agent,
|
||||
"status": status,
|
||||
"agent_response": body
|
||||
});
|
||||
let mut error_response = Response::new(full(error_json.to_string()));
|
||||
*error_response.status_mut() =
|
||||
StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST);
|
||||
error_response.headers_mut().insert(
|
||||
hyper::header::CONTENT_TYPE,
|
||||
"application/json".parse().unwrap(),
|
||||
);
|
||||
return Ok(error_response);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %err, "filter chain processing failed");
|
||||
let err_msg = format!("Filter chain processing failed: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref client_api_kind) = client_api {
|
||||
let upstream_api =
|
||||
provider_id.compatible_api_for_client(client_api_kind, is_streaming_request);
|
||||
client_request.normalize_for_upstream(provider_id, &upstream_api);
|
||||
}
|
||||
// === v1/responses state management: Determine upstream API and combine input if needed ===
|
||||
// Do this BEFORE routing since routing consumes the request
|
||||
// Only process state if state_storage is configured
|
||||
let mut should_manage_state = false;
|
||||
if is_responses_api_client {
|
||||
if let (
|
||||
ProviderRequestType::ResponsesAPIRequest(ref mut responses_req),
|
||||
Some(ref state_store),
|
||||
) = (&mut client_request, &state_storage)
|
||||
{
|
||||
// Extract original input once
|
||||
original_input_items = extract_input_items(&responses_req.input);
|
||||
|
||||
// Get the upstream path and check if it's ResponsesAPI
|
||||
let upstream_path = get_upstream_path(
|
||||
&llm_providers,
|
||||
&alias_resolved_model,
|
||||
&request_path,
|
||||
&alias_resolved_model,
|
||||
is_streaming_request,
|
||||
)
|
||||
.await;
|
||||
|
||||
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
|
||||
|
||||
// Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
|
||||
should_manage_state = !matches!(
|
||||
upstream_api,
|
||||
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
|
||||
);
|
||||
|
||||
if should_manage_state {
|
||||
// Retrieve and combine conversation history if previous_response_id exists
|
||||
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
|
||||
match retrieve_and_combine_input(
|
||||
state_store.clone(),
|
||||
prev_resp_id,
|
||||
original_input_items, // Pass ownership instead of cloning
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(combined_input) => {
|
||||
// Update both the request and original_input_items
|
||||
responses_req.input = InputParam::Items(combined_input.clone());
|
||||
original_input_items = combined_input;
|
||||
info!(
|
||||
items = original_input_items.len(),
|
||||
"updated request with conversation history"
|
||||
);
|
||||
}
|
||||
Err(StateStorageError::NotFound(_)) => {
|
||||
// Return 409 Conflict when previous_response_id not found
|
||||
warn!(previous_response_id = %prev_resp_id, "previous response_id not found");
|
||||
return Ok(BrightStaffError::ConversationStateNotFound(
|
||||
prev_resp_id.to_string(),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
Err(e) => {
|
||||
// Log warning but continue on other storage errors
|
||||
warn!(
|
||||
previous_response_id = %prev_resp_id,
|
||||
error = %e,
|
||||
"failed to retrieve conversation state"
|
||||
);
|
||||
// Restore original_input_items since we passed ownership
|
||||
original_input_items = extract_input_items(&responses_req.input);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("upstream supports ResponsesAPI natively");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize request for upstream BEFORE router consumes it
|
||||
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
|
||||
|
||||
// Determine routing using the dedicated router_chat module
|
||||
// This gets its own span for latency and error tracking
|
||||
let routing_span = info_span!(
|
||||
"routing",
|
||||
component = "routing",
|
||||
http.method = "POST",
|
||||
http.target = %request_path,
|
||||
model.requested = %model_from_request,
|
||||
model.alias_resolved = %alias_resolved_model,
|
||||
route.selected_model = tracing::field::Empty,
|
||||
routing.determination_ms = tracing::field::Empty,
|
||||
);
|
||||
let routing_result = match async {
|
||||
set_service_name(operation_component::ROUTING);
|
||||
router_chat_get_upstream_model(
|
||||
router_service,
|
||||
client_request, // Pass the original request - router_chat will convert it
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_routing_policy,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.instrument(routing_span)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
return Ok(BrightStaffError::ForwardedError {
|
||||
status_code: err.status_code,
|
||||
message: err.message,
|
||||
}
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
// Determine final model to use
|
||||
// Router returns "none" as a sentinel value when it doesn't select a specific model
|
||||
let router_selected_model = routing_result.model_name;
|
||||
let resolved_model = if router_selected_model != "none" {
|
||||
// Router selected a specific model via routing preferences
|
||||
router_selected_model
|
||||
} else {
|
||||
// Router returned "none" sentinel, use validated resolved_model from request
|
||||
alias_resolved_model.clone()
|
||||
};
|
||||
tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str());
|
||||
|
||||
let span_name = if model_from_request == resolved_model {
|
||||
format!("POST {} {}", request_path, resolved_model)
|
||||
} else {
|
||||
format!(
|
||||
"POST {} {} -> {}",
|
||||
request_path, model_from_request, resolved_model
|
||||
)
|
||||
};
|
||||
get_active_span(|span| {
|
||||
span.update_name(span_name.clone());
|
||||
});
|
||||
|
||||
debug!(
|
||||
url = %full_qualified_llm_provider_url,
|
||||
provider_hint = %resolved_model,
|
||||
upstream_model = %model_name_only,
|
||||
"Routing to upstream"
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
ARCH_PROVIDER_HINT_HEADER,
|
||||
header::HeaderValue::from_str(&resolved_model).unwrap(),
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
||||
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
|
||||
);
|
||||
// remove content-length header if it exists
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
// Inject current LLM span's trace context so upstream spans are children of plano(llm)
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
|
||||
});
|
||||
|
||||
// Output filters run for any API shape that reaches this handler (e.g. /v1/chat/completions,
|
||||
// /v1/messages, /v1/responses). Brightstaff does inbound translation and llm_gateway does
|
||||
// outbound translation; filters receive raw response bytes and request path.
|
||||
let has_output_filter = filter_pipeline.has_output_filters();
|
||||
|
||||
// Save request headers for output filters (before they're consumed by upstream request)
|
||||
let output_filter_request_headers = if has_output_filter {
|
||||
Some(request_headers.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Capture start time right before sending request to upstream
|
||||
let request_start_time = std::time::Instant::now();
|
||||
let _request_start_system_time = std::time::SystemTime::now();
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post(&full_qualified_llm_provider_url)
|
||||
.headers(request_headers)
|
||||
.body(client_request_bytes_for_upstream)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
return Ok(BrightStaffError::InternalServerError(format!(
|
||||
"Failed to send request: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
// copy over the headers and status code from the original response
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let upstream_status = llm_response.status();
|
||||
let mut response = Response::builder().status(upstream_status);
|
||||
let headers = response.headers_mut().unwrap();
|
||||
for (header_name, header_value) in response_headers.iter() {
|
||||
headers.insert(header_name, header_value.clone());
|
||||
}
|
||||
|
||||
// Build LLM span with actual status code using constants
|
||||
let byte_stream = llm_response.bytes_stream();
|
||||
|
||||
// Create base processor for metrics and tracing
|
||||
let base_processor = ObservableStreamProcessor::new(
|
||||
operation_component::LLM,
|
||||
span_name,
|
||||
request_start_time,
|
||||
messages_for_signals,
|
||||
);
|
||||
|
||||
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
|
||||
// Pick the right processor: state-aware if needed, otherwise base metrics-only.
|
||||
let processor: Box<dyn StreamProcessor> = if let (true, false, Some(state_store)) = (
|
||||
should_manage_state,
|
||||
original_input_items.is_empty(),
|
||||
state_storage,
|
||||
) {
|
||||
let content_encoding = response_headers
|
||||
.get("content-encoding")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
Box::new(ResponsesStateProcessor::new(
|
||||
base_processor,
|
||||
state_store,
|
||||
original_input_items,
|
||||
alias_resolved_model.clone(),
|
||||
resolved_model.clone(),
|
||||
is_streaming_request,
|
||||
false, // Not OpenAI upstream since should_manage_state is true
|
||||
content_encoding,
|
||||
request_id,
|
||||
))
|
||||
} else {
|
||||
Box::new(base_processor)
|
||||
};
|
||||
|
||||
// Apply output filters if configured, then build the streaming response.
|
||||
let streaming_response = if has_output_filter {
|
||||
let output_chain = filter_pipeline.output.as_ref().unwrap().clone();
|
||||
create_streaming_response_with_output_filter(
|
||||
byte_stream,
|
||||
processor,
|
||||
output_chain,
|
||||
output_filter_request_headers.unwrap(),
|
||||
request_path.clone(),
|
||||
)
|
||||
} else {
|
||||
create_streaming_response(byte_stream, processor)
|
||||
};
|
||||
|
||||
match response.body(streaming_response.body) {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => Ok(BrightStaffError::InternalServerError(format!(
|
||||
"Failed to create response: {}",
|
||||
err
|
||||
))
|
||||
.into_response()),
|
||||
}
|
||||
}
|
||||
/// Resolves model aliases by looking up the requested model in the model_aliases map.
|
||||
/// Returns the target model if an alias is found, otherwise returns the original model.
|
||||
fn resolve_model_alias(
|
||||
model_from_request: &str,
|
||||
model_aliases: &Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
) -> String {
|
||||
if let Some(aliases) = model_aliases.as_ref() {
|
||||
if let Some(model_alias) = aliases.get(model_from_request) {
|
||||
debug!(
|
||||
"Model Alias: 'From {}' -> 'To {}'",
|
||||
model_from_request, model_alias.target
|
||||
);
|
||||
return model_alias.target.clone();
|
||||
}
|
||||
}
|
||||
model_from_request.to_string()
|
||||
}
|
||||
|
||||
/// Calculates the upstream path for the provider based on the model name.
|
||||
/// Looks up provider configuration, gets the ProviderId and base_url_path_prefix,
|
||||
/// then uses target_endpoint_for_provider to calculate the correct upstream path.
|
||||
async fn get_upstream_path(
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
model_name: &str,
|
||||
request_path: &str,
|
||||
resolved_model: &str,
|
||||
is_streaming: bool,
|
||||
) -> String {
|
||||
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
|
||||
|
||||
// Calculate the upstream path using the proper API
|
||||
let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
|
||||
.expect("Should have valid API endpoint");
|
||||
|
||||
client_api.target_endpoint_for_provider(
|
||||
&provider_id,
|
||||
request_path,
|
||||
resolved_model,
|
||||
is_streaming,
|
||||
base_url_path_prefix.as_deref(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Helper function to get provider info (ProviderId and base_url_path_prefix)
|
||||
async fn get_provider_info(
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
model_name: &str,
|
||||
) -> (hermesllm::ProviderId, Option<String>) {
|
||||
let providers_lock = llm_providers.read().await;
|
||||
|
||||
// Try to find by model name or provider name using LlmProviders::get
|
||||
// This handles both "gpt-4" and "openai/gpt-4" formats
|
||||
if let Some(provider) = providers_lock.get(model_name) {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
return (provider_id, prefix);
|
||||
}
|
||||
|
||||
// Fall back to default provider
|
||||
if let Some(provider) = providers_lock.default() {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
(provider_id, prefix)
|
||||
} else {
|
||||
// Last resort: use OpenAI as hardcoded fallback
|
||||
warn!("No default provider found, falling back to OpenAI");
|
||||
(hermesllm::ProviderId::OpenAI, None)
|
||||
}
|
||||
}
|
||||
|
||||
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
780
crates/brightstaff/src/handlers/llm/mod.rs
Normal file
780
crates/brightstaff/src/handlers/llm/mod.rs
Normal file
|
|
@ -0,0 +1,780 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{FilterPipeline, ModelAlias};
|
||||
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hermesllm::apis::openai_responses::InputParam;
|
||||
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use opentelemetry::global;
|
||||
use opentelemetry::trace::get_active_span;
|
||||
use opentelemetry_http::HeaderInjector;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
pub(crate) mod model_selection;
|
||||
|
||||
use crate::app_state::AppState;
|
||||
use crate::handlers::agents::pipeline::PipelineProcessor;
|
||||
use crate::handlers::extract_or_generate_traceparent;
|
||||
use crate::handlers::extract_request_id;
|
||||
use crate::handlers::full;
|
||||
use crate::state::response_state_processor::ResponsesStateProcessor;
|
||||
use crate::state::{
|
||||
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
|
||||
};
|
||||
use crate::streaming::{
|
||||
create_streaming_response, create_streaming_response_with_output_filter, truncate_message,
|
||||
ObservableStreamProcessor, StreamProcessor,
|
||||
};
|
||||
use crate::tracing::{
|
||||
collect_custom_trace_attributes, llm as tracing_llm, operation_component, set_service_name,
|
||||
};
|
||||
use model_selection::router_chat_get_upstream_model;
|
||||
|
||||
pub async fn llm_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
state: Arc<AppState>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_path = request.uri().path().to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
let request_id = extract_request_id(&request);
|
||||
let custom_attrs =
|
||||
collect_custom_trace_attributes(&request_headers, state.span_attributes.as_ref());
|
||||
|
||||
// Create a span with request_id that will be included in all log lines
|
||||
let request_span = info_span!(
|
||||
"llm",
|
||||
component = "llm",
|
||||
request_id = %request_id,
|
||||
http.method = %request.method(),
|
||||
http.path = %request_path,
|
||||
llm.model = tracing::field::Empty,
|
||||
llm.tools = tracing::field::Empty,
|
||||
llm.user_message_preview = tracing::field::Empty,
|
||||
llm.temperature = tracing::field::Empty,
|
||||
);
|
||||
|
||||
// Execute the rest of the handler inside the span
|
||||
llm_chat_inner(
|
||||
request,
|
||||
state,
|
||||
custom_attrs,
|
||||
request_id,
|
||||
request_path,
|
||||
request_headers,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn llm_chat_inner(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
state: Arc<AppState>,
|
||||
custom_attrs: HashMap<String, String>,
|
||||
request_id: String,
|
||||
request_path: String,
|
||||
mut request_headers: hyper::HeaderMap,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
// Set service name for LLM operations
|
||||
set_service_name(operation_component::LLM);
|
||||
get_active_span(|span| {
|
||||
for (key, value) in &custom_attrs {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
|
||||
}
|
||||
});
|
||||
|
||||
let traceparent = extract_or_generate_traceparent(&request_headers);
|
||||
|
||||
let full_qualified_llm_provider_url = format!("{}{}", state.llm_provider_url, request_path);
|
||||
|
||||
// --- Phase 1: Parse and validate the incoming request ---
|
||||
let parsed = match parse_and_validate_request(
|
||||
request,
|
||||
&request_path,
|
||||
&state.model_aliases,
|
||||
&state.llm_providers,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(response) => return Ok(response),
|
||||
};
|
||||
|
||||
let PreparedRequest {
|
||||
mut client_request,
|
||||
chat_request_bytes,
|
||||
model_from_request,
|
||||
alias_resolved_model,
|
||||
model_name_only,
|
||||
is_streaming_request,
|
||||
is_responses_api_client,
|
||||
messages_for_signals,
|
||||
temperature,
|
||||
tool_names,
|
||||
user_message_preview,
|
||||
inline_routing_policy,
|
||||
client_api,
|
||||
provider_id,
|
||||
} = parsed;
|
||||
|
||||
// Record LLM-specific span attributes
|
||||
let span = tracing::Span::current();
|
||||
if let Some(temp) = temperature {
|
||||
span.record(tracing_llm::TEMPERATURE, tracing::field::display(temp));
|
||||
}
|
||||
if let Some(tools) = &tool_names {
|
||||
let formatted_tools = tools
|
||||
.iter()
|
||||
.map(|name| format!("{}(...)", name))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
span.record(tracing_llm::TOOLS, formatted_tools.as_str());
|
||||
}
|
||||
if let Some(preview) = &user_message_preview {
|
||||
span.record(tracing_llm::USER_MESSAGE_PREVIEW, preview.as_str());
|
||||
}
|
||||
|
||||
// --- Phase 1b: Input filter processing for model listener ---
|
||||
if let Some(ref input_chain) = state.filter_pipeline.input {
|
||||
if !input_chain.is_empty() {
|
||||
debug!(input_filters = ?input_chain.filter_ids, "processing model listener input filters");
|
||||
let chain = input_chain.to_agent_filter_chain("model_listener");
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
match pipeline_processor
|
||||
.process_raw_filter_chain(
|
||||
&chat_request_bytes,
|
||||
&chain,
|
||||
&input_chain.agents,
|
||||
&request_headers,
|
||||
&request_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_bytes) => {
|
||||
let api_type = SupportedAPIsFromClient::from_endpoint(request_path.as_str())
|
||||
.expect("endpoint validated in parse_and_validate_request");
|
||||
match ProviderRequestType::try_from((&filtered_bytes[..], &api_type)) {
|
||||
Ok(updated_request) => {
|
||||
client_request = updated_request;
|
||||
info!("input filter chain processed successfully");
|
||||
}
|
||||
Err(parse_err) => {
|
||||
warn!(error = %parse_err, "input filter returned invalid request JSON");
|
||||
return Ok(common::errors::BrightStaffError::InvalidRequest(format!(
|
||||
"Input filter returned invalid request: {}",
|
||||
parse_err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(super::agents::pipeline::PipelineError::ClientError {
|
||||
agent,
|
||||
status,
|
||||
body,
|
||||
}) => {
|
||||
warn!(agent = %agent, status = %status, body = %body, "client error from filter chain");
|
||||
let error_json = serde_json::json!({
|
||||
"error": "FilterChainError",
|
||||
"agent": agent,
|
||||
"status": status,
|
||||
"agent_response": body
|
||||
});
|
||||
let mut error_response = Response::new(full(error_json.to_string()));
|
||||
*error_response.status_mut() =
|
||||
StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST);
|
||||
error_response.headers_mut().insert(
|
||||
hyper::header::CONTENT_TYPE,
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
return Ok(error_response);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %err, "filter chain processing failed");
|
||||
let mut internal_error =
|
||||
Response::new(full(format!("Filter chain processing failed: {}", err)));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize for upstream after input filters
|
||||
if let Some(ref client_api_kind) = client_api {
|
||||
let upstream_api =
|
||||
provider_id.compatible_api_for_client(client_api_kind, is_streaming_request);
|
||||
client_request.normalize_for_upstream(provider_id, &upstream_api);
|
||||
}
|
||||
|
||||
// --- Phase 2: Resolve conversation state (v1/responses API) ---
|
||||
let state_ctx = match resolve_conversation_state(
|
||||
&mut client_request,
|
||||
is_responses_api_client,
|
||||
&state.state_storage,
|
||||
&state.llm_providers,
|
||||
&alias_resolved_model,
|
||||
&request_path,
|
||||
is_streaming_request,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(ctx) => ctx,
|
||||
Err(response) => return Ok(response),
|
||||
};
|
||||
|
||||
// Serialize request for upstream BEFORE router consumes it
|
||||
let client_request_bytes_for_upstream: Bytes =
|
||||
match ProviderRequestType::to_bytes(&client_request) {
|
||||
Ok(bytes) => bytes.into(),
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to serialize request for upstream");
|
||||
let mut r = Response::new(full(format!("Failed to serialize request: {}", err)));
|
||||
*r.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(r);
|
||||
}
|
||||
};
|
||||
|
||||
// --- Phase 3: Route the request ---
|
||||
let routing_span = info_span!(
|
||||
"routing",
|
||||
component = "routing",
|
||||
http.method = "POST",
|
||||
http.target = %request_path,
|
||||
model.requested = %model_from_request,
|
||||
model.alias_resolved = %alias_resolved_model,
|
||||
route.selected_model = tracing::field::Empty,
|
||||
routing.determination_ms = tracing::field::Empty,
|
||||
);
|
||||
let routing_result = match async {
|
||||
set_service_name(operation_component::ROUTING);
|
||||
router_chat_get_upstream_model(
|
||||
Arc::clone(&state.router_service),
|
||||
client_request,
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_routing_policy,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.instrument(routing_span)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
let mut internal_error = Response::new(full(err.message));
|
||||
*internal_error.status_mut() = err.status_code;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
// Determine final model (router returns "none" when it doesn't select a specific model)
|
||||
let router_selected_model = routing_result.model_name;
|
||||
let resolved_model = if router_selected_model != "none" {
|
||||
router_selected_model
|
||||
} else {
|
||||
alias_resolved_model.clone()
|
||||
};
|
||||
tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str());
|
||||
|
||||
// --- Phase 4: Forward to upstream and stream back ---
|
||||
send_upstream(
|
||||
&state.http_client,
|
||||
&full_qualified_llm_provider_url,
|
||||
&mut request_headers,
|
||||
client_request_bytes_for_upstream,
|
||||
&model_from_request,
|
||||
&alias_resolved_model,
|
||||
&resolved_model,
|
||||
&model_name_only,
|
||||
&request_path,
|
||||
is_streaming_request,
|
||||
messages_for_signals,
|
||||
state_ctx,
|
||||
state.state_storage.clone(),
|
||||
request_id,
|
||||
&state.filter_pipeline,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Phase 1 — Parse & validate the incoming request
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// All pre-validated request data extracted from the raw HTTP request.
|
||||
struct PreparedRequest {
|
||||
client_request: ProviderRequestType,
|
||||
chat_request_bytes: Bytes,
|
||||
model_from_request: String,
|
||||
alias_resolved_model: String,
|
||||
model_name_only: String,
|
||||
is_streaming_request: bool,
|
||||
is_responses_api_client: bool,
|
||||
messages_for_signals: Option<Vec<Message>>,
|
||||
temperature: Option<f32>,
|
||||
tool_names: Option<Vec<String>>,
|
||||
user_message_preview: Option<String>,
|
||||
inline_routing_policy: Option<Vec<common::configuration::ModelUsagePreference>>,
|
||||
client_api: Option<SupportedAPIsFromClient>,
|
||||
provider_id: hermesllm::ProviderId,
|
||||
}
|
||||
|
||||
/// Parse the body, resolve the model alias, and validate the model exists.
|
||||
///
|
||||
/// Returns `Err(Response)` for early-exit error responses (400 etc.).
|
||||
async fn parse_and_validate_request(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
request_path: &str,
|
||||
model_aliases: &Option<HashMap<String, ModelAlias>>,
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
) -> Result<PreparedRequest, Response<BoxBody<Bytes, hyper::Error>>> {
|
||||
let raw_bytes = request
|
||||
.collect()
|
||||
.await
|
||||
.map_err(|_| {
|
||||
let mut r = Response::new(full("Failed to read request body"));
|
||||
*r.status_mut() = StatusCode::BAD_REQUEST;
|
||||
r
|
||||
})?
|
||||
.to_bytes();
|
||||
|
||||
debug!(
|
||||
body = %String::from_utf8_lossy(&raw_bytes),
|
||||
"request body received"
|
||||
);
|
||||
|
||||
// Extract routing_policy from request body if present
|
||||
let (chat_request_bytes, inline_routing_policy) =
|
||||
crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false).map_err(
|
||||
|err| {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
let mut r = Response::new(full(format!("Failed to parse request: {}", err)));
|
||||
*r.status_mut() = StatusCode::BAD_REQUEST;
|
||||
r
|
||||
},
|
||||
)?;
|
||||
|
||||
let api_type = SupportedAPIsFromClient::from_endpoint(request_path).ok_or_else(|| {
|
||||
warn!(path = %request_path, "unsupported endpoint");
|
||||
let mut r = Response::new(full(format!("Unsupported endpoint: {}", request_path)));
|
||||
*r.status_mut() = StatusCode::BAD_REQUEST;
|
||||
r
|
||||
})?;
|
||||
|
||||
let mut client_request = ProviderRequestType::try_from((&chat_request_bytes[..], &api_type))
|
||||
.map_err(|err| {
|
||||
warn!(error = %err, "failed to parse request as ProviderRequestType");
|
||||
let mut r = Response::new(full(format!("Failed to parse request: {}", err)));
|
||||
*r.status_mut() = StatusCode::BAD_REQUEST;
|
||||
r
|
||||
})?;
|
||||
|
||||
let is_responses_api_client =
|
||||
matches!(api_type, SupportedAPIsFromClient::OpenAIResponsesAPI(_));
|
||||
let client_api = Some(api_type);
|
||||
|
||||
let model_from_request = client_request.model().to_string();
|
||||
let temperature = client_request.get_temperature();
|
||||
let is_streaming_request = client_request.is_streaming();
|
||||
let alias_resolved_model = resolve_model_alias(&model_from_request, model_aliases);
|
||||
let (provider_id, _) = get_provider_info(llm_providers, &alias_resolved_model).await;
|
||||
|
||||
// Validate model exists in configuration
|
||||
if llm_providers
|
||||
.read()
|
||||
.await
|
||||
.get(&alias_resolved_model)
|
||||
.is_none()
|
||||
{
|
||||
let err_msg = format!(
|
||||
"Model '{}' not found in configured providers",
|
||||
alias_resolved_model
|
||||
);
|
||||
warn!(model = %alias_resolved_model, "model not found in configured providers");
|
||||
let mut r = Response::new(full(err_msg));
|
||||
*r.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Err(r);
|
||||
}
|
||||
|
||||
// Strip provider prefix for upstream (e.g. "openai/gpt-4" → "gpt-4")
|
||||
let model_name_only = alias_resolved_model
|
||||
.split_once('/')
|
||||
.map(|(_, model)| model.to_string())
|
||||
.unwrap_or_else(|| alias_resolved_model.clone());
|
||||
|
||||
// Extract span attributes and messages before mutating client_request
|
||||
let tool_names = client_request.get_tool_names();
|
||||
let user_message_preview = client_request
|
||||
.get_recent_user_message()
|
||||
.map(|msg| truncate_message(&msg, 50));
|
||||
let messages_for_signals = Some(client_request.get_messages());
|
||||
|
||||
// Set the upstream model name and strip routing metadata
|
||||
client_request.set_model(model_name_only.clone());
|
||||
if client_request.remove_metadata_key("archgw_preference_config") {
|
||||
debug!("removed archgw_preference_config from metadata");
|
||||
}
|
||||
if client_request.remove_metadata_key("plano_preference_config") {
|
||||
debug!("removed plano_preference_config from metadata");
|
||||
}
|
||||
|
||||
Ok(PreparedRequest {
|
||||
client_request,
|
||||
chat_request_bytes,
|
||||
model_from_request,
|
||||
alias_resolved_model,
|
||||
model_name_only,
|
||||
is_streaming_request,
|
||||
is_responses_api_client,
|
||||
messages_for_signals,
|
||||
temperature,
|
||||
tool_names,
|
||||
user_message_preview,
|
||||
inline_routing_policy,
|
||||
client_api,
|
||||
provider_id,
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Phase 2 — Resolve conversation state (v1/responses API)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Holds the state management context resolved from a v1/responses request.
|
||||
struct ConversationStateContext {
|
||||
should_manage_state: bool,
|
||||
original_input_items: Vec<hermesllm::apis::openai_responses::InputItem>,
|
||||
}
|
||||
|
||||
/// If the client uses the v1/responses API and the upstream provider doesn't
|
||||
/// support it natively, we manage conversation state ourselves.
|
||||
///
|
||||
/// This resolves `previous_response_id`, merges conversation history, and
|
||||
/// updates the request in place.
|
||||
///
|
||||
/// Returns `Err(Response)` for early-exit (e.g. 409 Conflict).
|
||||
async fn resolve_conversation_state(
|
||||
client_request: &mut ProviderRequestType,
|
||||
is_responses_api_client: bool,
|
||||
state_storage: &Option<Arc<dyn StateStorage>>,
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
alias_resolved_model: &str,
|
||||
request_path: &str,
|
||||
is_streaming_request: bool,
|
||||
) -> Result<ConversationStateContext, Response<BoxBody<Bytes, hyper::Error>>> {
|
||||
if !is_responses_api_client {
|
||||
return Ok(ConversationStateContext {
|
||||
should_manage_state: false,
|
||||
original_input_items: Vec::new(),
|
||||
});
|
||||
}
|
||||
|
||||
let (responses_req, state_store) = match (client_request, state_storage) {
|
||||
(ProviderRequestType::ResponsesAPIRequest(ref mut req), Some(store)) => (req, store),
|
||||
_ => {
|
||||
return Ok(ConversationStateContext {
|
||||
should_manage_state: false,
|
||||
original_input_items: Vec::new(),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let mut original_input_items = extract_input_items(&responses_req.input);
|
||||
|
||||
// Check whether the upstream supports v1/responses natively
|
||||
let upstream_path = get_upstream_path(
|
||||
llm_providers,
|
||||
alias_resolved_model,
|
||||
request_path,
|
||||
alias_resolved_model,
|
||||
is_streaming_request,
|
||||
)
|
||||
.await;
|
||||
|
||||
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
|
||||
let should_manage_state = !matches!(
|
||||
upstream_api,
|
||||
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
|
||||
);
|
||||
|
||||
if !should_manage_state {
|
||||
debug!("upstream supports ResponsesAPI natively");
|
||||
return Ok(ConversationStateContext {
|
||||
should_manage_state: false,
|
||||
original_input_items,
|
||||
});
|
||||
}
|
||||
|
||||
// Retrieve and combine conversation history if previous_response_id exists
|
||||
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
|
||||
match retrieve_and_combine_input(state_store.clone(), prev_resp_id, original_input_items)
|
||||
.await
|
||||
{
|
||||
Ok(combined_input) => {
|
||||
responses_req.input = InputParam::Items(combined_input.clone());
|
||||
original_input_items = combined_input;
|
||||
info!(
|
||||
items = original_input_items.len(),
|
||||
"updated request with conversation history"
|
||||
);
|
||||
}
|
||||
Err(StateStorageError::NotFound(_)) => {
|
||||
warn!(previous_response_id = %prev_resp_id, "previous response_id not found");
|
||||
let err_msg = format!(
|
||||
"Conversation state not found for previous_response_id: {}",
|
||||
prev_resp_id
|
||||
);
|
||||
let mut r = Response::new(full(err_msg));
|
||||
*r.status_mut() = StatusCode::CONFLICT;
|
||||
return Err(r);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
previous_response_id = %prev_resp_id,
|
||||
error = %e,
|
||||
"failed to retrieve conversation state"
|
||||
);
|
||||
// Restore original_input_items since we passed ownership
|
||||
original_input_items = extract_input_items(&responses_req.input);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ConversationStateContext {
|
||||
should_manage_state,
|
||||
original_input_items,
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Phase 4 — Forward to upstream and stream the response back
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn send_upstream(
|
||||
http_client: &reqwest::Client,
|
||||
upstream_url: &str,
|
||||
request_headers: &mut hyper::HeaderMap,
|
||||
body: bytes::Bytes,
|
||||
model_from_request: &str,
|
||||
alias_resolved_model: &str,
|
||||
resolved_model: &str,
|
||||
model_name_only: &str,
|
||||
request_path: &str,
|
||||
is_streaming_request: bool,
|
||||
messages_for_signals: Option<Vec<Message>>,
|
||||
state_ctx: ConversationStateContext,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
request_id: String,
|
||||
filter_pipeline: &Arc<FilterPipeline>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let span_name = if model_from_request == resolved_model {
|
||||
format!("POST {} {}", request_path, resolved_model)
|
||||
} else {
|
||||
format!(
|
||||
"POST {} {} -> {}",
|
||||
request_path, model_from_request, resolved_model
|
||||
)
|
||||
};
|
||||
get_active_span(|span| {
|
||||
span.update_name(span_name.clone());
|
||||
});
|
||||
|
||||
debug!(
|
||||
url = %upstream_url,
|
||||
provider_hint = %resolved_model,
|
||||
upstream_model = %model_name_only,
|
||||
"Routing to upstream"
|
||||
);
|
||||
|
||||
if let Ok(val) = header::HeaderValue::from_str(resolved_model) {
|
||||
request_headers.insert(ARCH_PROVIDER_HINT_HEADER, val);
|
||||
}
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
||||
header::HeaderValue::from_static(if is_streaming_request {
|
||||
"true"
|
||||
} else {
|
||||
"false"
|
||||
}),
|
||||
);
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
// Inject current span's trace context so upstream spans are children of plano(llm)
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(request_headers));
|
||||
});
|
||||
|
||||
let request_start_time = std::time::Instant::now();
|
||||
|
||||
let llm_response = match http_client
|
||||
.post(upstream_url)
|
||||
.headers(request_headers.clone())
|
||||
.body(body)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
// Propagate upstream headers and status
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let upstream_status = llm_response.status();
|
||||
let mut response = Response::builder().status(upstream_status);
|
||||
if let Some(headers) = response.headers_mut() {
|
||||
for (name, value) in response_headers.iter() {
|
||||
headers.insert(name, value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let byte_stream = llm_response.bytes_stream();
|
||||
|
||||
// Create base processor for metrics and tracing
|
||||
let base_processor = ObservableStreamProcessor::new(
|
||||
operation_component::LLM,
|
||||
span_name,
|
||||
request_start_time,
|
||||
messages_for_signals,
|
||||
);
|
||||
|
||||
let output_filter_request_headers = if filter_pipeline.has_output_filters() {
|
||||
Some(request_headers.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Pick the right processor: state-aware if needed, otherwise base metrics-only.
|
||||
let processor: Box<dyn StreamProcessor> = if let (true, false, Some(state_store)) = (
|
||||
state_ctx.should_manage_state,
|
||||
state_ctx.original_input_items.is_empty(),
|
||||
state_storage,
|
||||
) {
|
||||
let content_encoding = response_headers
|
||||
.get("content-encoding")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
Box::new(ResponsesStateProcessor::new(
|
||||
base_processor,
|
||||
state_store,
|
||||
state_ctx.original_input_items,
|
||||
alias_resolved_model.to_string(),
|
||||
resolved_model.to_string(),
|
||||
is_streaming_request,
|
||||
false,
|
||||
content_encoding,
|
||||
request_id,
|
||||
))
|
||||
} else {
|
||||
Box::new(base_processor)
|
||||
};
|
||||
|
||||
let streaming_response = if let (Some(output_chain), Some(filter_headers)) = (
|
||||
filter_pipeline.output.as_ref().filter(|c| !c.is_empty()),
|
||||
output_filter_request_headers,
|
||||
) {
|
||||
create_streaming_response_with_output_filter(
|
||||
byte_stream,
|
||||
processor,
|
||||
output_chain.clone(),
|
||||
filter_headers,
|
||||
request_path.to_string(),
|
||||
)
|
||||
} else {
|
||||
create_streaming_response(byte_stream, processor)
|
||||
};
|
||||
|
||||
match response.body(streaming_response.body) {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to create response: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
Ok(internal_error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Resolves model aliases by looking up the requested model in the model_aliases map.
|
||||
/// Returns the target model if an alias is found, otherwise returns the original model.
|
||||
fn resolve_model_alias(
|
||||
model_from_request: &str,
|
||||
model_aliases: &Option<HashMap<String, ModelAlias>>,
|
||||
) -> String {
|
||||
if let Some(aliases) = model_aliases.as_ref() {
|
||||
if let Some(model_alias) = aliases.get(model_from_request) {
|
||||
debug!(
|
||||
"Model Alias: 'From {}' -> 'To {}'",
|
||||
model_from_request, model_alias.target
|
||||
);
|
||||
return model_alias.target.clone();
|
||||
}
|
||||
}
|
||||
model_from_request.to_string()
|
||||
}
|
||||
|
||||
/// Calculates the upstream path for the provider based on the model name.
|
||||
async fn get_upstream_path(
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
model_name: &str,
|
||||
request_path: &str,
|
||||
resolved_model: &str,
|
||||
is_streaming: bool,
|
||||
) -> String {
|
||||
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
|
||||
|
||||
let Some(client_api) = SupportedAPIsFromClient::from_endpoint(request_path) else {
|
||||
return request_path.to_string();
|
||||
};
|
||||
|
||||
client_api.target_endpoint_for_provider(
|
||||
&provider_id,
|
||||
request_path,
|
||||
resolved_model,
|
||||
is_streaming,
|
||||
base_url_path_prefix.as_deref(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Helper to get provider info (ProviderId and base_url_path_prefix).
|
||||
async fn get_provider_info(
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
model_name: &str,
|
||||
) -> (hermesllm::ProviderId, Option<String>) {
|
||||
let providers_lock = llm_providers.read().await;
|
||||
|
||||
if let Some(provider) = providers_lock.get(model_name) {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
return (provider_id, prefix);
|
||||
}
|
||||
|
||||
if let Some(provider) = providers_lock.default() {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
(provider_id, prefix)
|
||||
} else {
|
||||
warn!("No default provider found, falling back to OpenAI");
|
||||
(hermesllm::ProviderId::OpenAI, None)
|
||||
}
|
||||
}
|
||||
|
|
@ -5,7 +5,8 @@ use hyper::StatusCode;
|
|||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::router::llm::RouterService;
|
||||
use crate::streaming::truncate_message;
|
||||
use crate::tracing::routing;
|
||||
|
||||
pub struct RoutingResult {
|
||||
|
|
@ -103,16 +104,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
.map_or("None".to_string(), |c| c.to_string().replace('\n', "\\n"))
|
||||
});
|
||||
|
||||
const MAX_MESSAGE_LENGTH: usize = 50;
|
||||
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
|
||||
let truncated: String = latest_message_for_log
|
||||
.chars()
|
||||
.take(MAX_MESSAGE_LENGTH)
|
||||
.collect();
|
||||
format!("{}...", truncated)
|
||||
} else {
|
||||
latest_message_for_log
|
||||
};
|
||||
let latest_message_for_log = truncate_message(&latest_message_for_log, 50);
|
||||
|
||||
info!(
|
||||
has_usage_preferences = usage_preferences.is_some(),
|
||||
|
|
@ -1,14 +1,57 @@
|
|||
pub mod agent_chat_completions;
|
||||
pub mod agent_selector;
|
||||
pub mod agents;
|
||||
pub mod function_calling;
|
||||
pub mod jsonrpc;
|
||||
pub mod llm;
|
||||
pub mod models;
|
||||
pub mod pipeline_processor;
|
||||
pub mod response_handler;
|
||||
pub mod router_chat;
|
||||
pub mod response;
|
||||
pub mod routing_service;
|
||||
pub mod streaming;
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::consts::TRACE_PARENT_HEADER;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Empty, Full};
|
||||
use hyper::Request;
|
||||
use tracing::warn;
|
||||
|
||||
/// Wrap a chunk into a `BoxBody` for hyper responses.
|
||||
pub fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// An empty HTTP body (used for 404 / OPTIONS responses).
|
||||
pub fn empty() -> BoxBody<Bytes, hyper::Error> {
|
||||
Empty::<Bytes>::new()
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// Extract request ID from incoming request headers, or generate a new UUID v4.
|
||||
pub fn extract_request_id<T>(request: &Request<T>) -> String {
|
||||
request
|
||||
.headers()
|
||||
.get(common::consts::REQUEST_ID_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
|
||||
}
|
||||
|
||||
/// Extract or generate a W3C `traceparent` header value.
|
||||
pub fn extract_or_generate_traceparent(headers: &hyper::HeaderMap) -> String {
|
||||
headers
|
||||
.get(TRACE_PARENT_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| {
|
||||
let trace_id = uuid::Uuid::new_v4().to_string().replace("-", "");
|
||||
let tp = format!("00-{}-0000000000000000-01", trace_id);
|
||||
warn!(
|
||||
generated_traceparent = %tp,
|
||||
"TRACE_PARENT header missing, generated new traceparent"
|
||||
);
|
||||
tp
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
use bytes::Bytes;
|
||||
use common::llm_providers::LlmProviders;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use hyper::{Response, StatusCode};
|
||||
use serde_json;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::full;
|
||||
|
||||
pub async fn list_models(
|
||||
llm_providers: Arc<tokio::sync::RwLock<LlmProviders>>,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
|
|
@ -12,27 +13,15 @@ pub async fn list_models(
|
|||
let models = prov.to_models();
|
||||
|
||||
match serde_json::to_string(&models) {
|
||||
Ok(json) => {
|
||||
let body = Full::new(Bytes::from(json))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.unwrap()
|
||||
}
|
||||
Err(_) => {
|
||||
let body = Full::new(Bytes::from_static(
|
||||
b"{\"error\":\"Failed to serialize models\"}",
|
||||
))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.unwrap()
|
||||
}
|
||||
Ok(json) => Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(full(json))
|
||||
.unwrap(),
|
||||
Err(_) => Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(full("{\"error\":\"Failed to serialize models\"}"))
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use hermesllm::apis::OpenAIApi;
|
|||
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use hermesllm::SseEvent;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use http_body_util::StreamBody;
|
||||
use hyper::body::Frame;
|
||||
use hyper::Response;
|
||||
use tokio::sync::mpsc;
|
||||
|
|
@ -12,6 +12,8 @@ use tokio_stream::wrappers::ReceiverStream;
|
|||
use tokio_stream::StreamExt;
|
||||
use tracing::{info, warn, Instrument};
|
||||
|
||||
use super::full;
|
||||
|
||||
/// Service for handling HTTP responses and streaming
|
||||
pub struct ResponseHandler;
|
||||
|
||||
|
|
@ -22,9 +24,40 @@ impl ResponseHandler {
|
|||
|
||||
/// Create a full response body from bytes
|
||||
pub fn create_full_body<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
full(chunk)
|
||||
}
|
||||
|
||||
/// Create a JSON error response with BAD_REQUEST status
|
||||
pub fn create_json_error_response(
|
||||
json: &serde_json::Value,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let body = Self::create_full_body(json.to_string());
|
||||
let mut response = Response::new(body);
|
||||
*response.status_mut() = hyper::StatusCode::BAD_REQUEST;
|
||||
response.headers_mut().insert(
|
||||
hyper::header::CONTENT_TYPE,
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
response
|
||||
}
|
||||
|
||||
/// Create a BAD_REQUEST error response with a message
|
||||
pub fn create_bad_request(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let json = serde_json::json!({"error": message});
|
||||
Self::create_json_error_response(&json)
|
||||
}
|
||||
|
||||
/// Create an INTERNAL_SERVER_ERROR response with a message
|
||||
pub fn create_internal_error(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let json = serde_json::json!({"error": message});
|
||||
let body = Self::create_full_body(json.to_string());
|
||||
let mut response = Response::new(body);
|
||||
*response.status_mut() = hyper::StatusCode::INTERNAL_SERVER_ERROR;
|
||||
response.headers_mut().insert(
|
||||
hyper::header::CONTENT_TYPE,
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
response
|
||||
}
|
||||
|
||||
/// Create a streaming response from a reqwest response.
|
||||
|
|
@ -112,7 +145,9 @@ impl ResponseHandler {
|
|||
let upstream_api =
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).unwrap();
|
||||
let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).map_err(|e| {
|
||||
BrightStaffError::StreamError(format!("Failed to parse SSE stream: {}", e))
|
||||
})?;
|
||||
let mut accumulated_text = String::new();
|
||||
|
||||
for sse_event in sse_iter {
|
||||
|
|
@ -122,7 +157,13 @@ impl ResponseHandler {
|
|||
}
|
||||
|
||||
let transformed_event =
|
||||
SseEvent::try_from((sse_event, &client_api, &upstream_api)).unwrap();
|
||||
match SseEvent::try_from((sse_event, &client_api, &upstream_api)) {
|
||||
Ok(event) => event,
|
||||
Err(e) => {
|
||||
warn!(error = ?e, "failed to transform SSE event, skipping");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Try to get provider response and extract content delta
|
||||
match transformed_event.provider_response() {
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{ModelUsagePreference, SpanAttributes};
|
||||
use common::consts::{REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
|
||||
use common::consts::REQUEST_ID_HEADER;
|
||||
use common::errors::BrightStaffError;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::ProviderRequestType;
|
||||
|
|
@ -10,8 +10,9 @@ use hyper::{Request, Response, StatusCode};
|
|||
use std::sync::Arc;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||
use crate::router::llm_router::RouterService;
|
||||
use super::extract_or_generate_traceparent;
|
||||
use crate::handlers::llm::model_selection::router_chat_get_upstream_model;
|
||||
use crate::router::llm::RouterService;
|
||||
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
|
||||
|
||||
const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
|
||||
|
|
@ -72,7 +73,7 @@ pub async fn routing_decision(
|
|||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
request_path: String,
|
||||
span_attributes: Arc<Option<SpanAttributes>>,
|
||||
span_attributes: &Option<SpanAttributes>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_headers = request.headers().clone();
|
||||
let request_id: String = request_headers
|
||||
|
|
@ -81,8 +82,7 @@ pub async fn routing_decision(
|
|||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
|
||||
|
||||
let custom_attrs =
|
||||
collect_custom_trace_attributes(&request_headers, span_attributes.as_ref().as_ref());
|
||||
let custom_attrs = collect_custom_trace_attributes(&request_headers, span_attributes.as_ref());
|
||||
|
||||
let request_span = info_span!(
|
||||
"routing_decision",
|
||||
|
|
@ -119,23 +119,7 @@ async fn routing_decision_inner(
|
|||
}
|
||||
});
|
||||
|
||||
// Extract or generate traceparent
|
||||
let traceparent: String = match request_headers
|
||||
.get(TRACE_PARENT_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
{
|
||||
Some(tp) => tp,
|
||||
None => {
|
||||
let trace_id = uuid::Uuid::new_v4().to_string().replace("-", "");
|
||||
let generated_tp = format!("00-{}-0000000000000000-01", trace_id);
|
||||
warn!(
|
||||
generated_traceparent = %generated_tp,
|
||||
"TRACE_PARENT header missing, generated new traceparent"
|
||||
);
|
||||
generated_tp
|
||||
}
|
||||
};
|
||||
let traceparent = extract_or_generate_traceparent(&request_headers);
|
||||
|
||||
// Extract trace_id from traceparent (format: 00-{trace_id}-{span_id}-{flags})
|
||||
let trace_id = traceparent
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
pub mod app_state;
|
||||
pub mod handlers;
|
||||
pub mod router;
|
||||
pub mod signals;
|
||||
pub mod state;
|
||||
pub mod streaming;
|
||||
pub mod tracing;
|
||||
pub mod utils;
|
||||
|
|
|
|||
|
|
@ -1,28 +1,31 @@
|
|||
use brightstaff::handlers::agent_chat_completions::agent_chat;
|
||||
use brightstaff::app_state::AppState;
|
||||
use brightstaff::handlers::agents::orchestrator::agent_chat;
|
||||
use brightstaff::handlers::empty;
|
||||
use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
||||
use brightstaff::handlers::llm::llm_chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::handlers::routing_service::routing_decision;
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
use brightstaff::router::plano_orchestrator::OrchestratorService;
|
||||
use brightstaff::router::llm::RouterService;
|
||||
use brightstaff::router::orchestrator::OrchestratorService;
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use brightstaff::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{
|
||||
Agent, Configuration, FilterPipeline, ListenerType, ResolvedFilterChain,
|
||||
};
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::header::HeaderValue;
|
||||
use hyper::server::conn::http1;
|
||||
use hyper::service::service_fn;
|
||||
use hyper::{Method, Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use opentelemetry::global;
|
||||
use opentelemetry::trace::FutureExt;
|
||||
use opentelemetry::{global, Context};
|
||||
use opentelemetry_http::HeaderExtractor;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
|
@ -31,82 +34,94 @@ use tokio::net::TcpListener;
|
|||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
pub mod router;
|
||||
|
||||
const BIND_ADDRESS: &str = "0.0.0.0:9091";
|
||||
const DEFAULT_ROUTING_LLM_PROVIDER: &str = "arch-router";
|
||||
const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router";
|
||||
const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator";
|
||||
const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
|
||||
|
||||
// Utility function to extract the context from the incoming request headers
|
||||
fn extract_context_from_request(req: &Request<Incoming>) -> Context {
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
propagator.extract(&HeaderExtractor(req.headers()))
|
||||
})
|
||||
/// CORS pre-flight response for the models endpoint.
|
||||
fn cors_preflight() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let mut response = Response::new(empty());
|
||||
*response.status_mut() = StatusCode::NO_CONTENT;
|
||||
let h = response.headers_mut();
|
||||
h.insert("Allow", HeaderValue::from_static("GET, OPTIONS"));
|
||||
h.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
|
||||
h.insert(
|
||||
"Access-Control-Allow-Headers",
|
||||
HeaderValue::from_static("Authorization, Content-Type"),
|
||||
);
|
||||
h.insert(
|
||||
"Access-Control-Allow-Methods",
|
||||
HeaderValue::from_static("GET, POST, OPTIONS"),
|
||||
);
|
||||
h.insert("Content-Type", HeaderValue::from_static("application/json"));
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
fn empty() -> BoxBody<Bytes, hyper::Error> {
|
||||
Empty::<Bytes>::new()
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
// ---------------------------------------------------------------------------
|
||||
// Configuration loading
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
|
||||
|
||||
// loading plano_config.yaml file (before tracing init so we can read tracing config)
|
||||
let plano_config_path = env::var("PLANO_CONFIG_PATH_RENDERED")
|
||||
/// Load and parse the YAML configuration file.
|
||||
///
|
||||
/// The path is read from `PLANO_CONFIG_PATH_RENDERED` (env) or falls back to
|
||||
/// `./plano_config_rendered.yaml`.
|
||||
fn load_config() -> Result<Configuration, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let path = env::var("PLANO_CONFIG_PATH_RENDERED")
|
||||
.unwrap_or_else(|_| "./plano_config_rendered.yaml".to_string());
|
||||
eprintln!("loading plano_config.yaml from {}", plano_config_path);
|
||||
eprintln!("loading plano_config.yaml from {}", path);
|
||||
|
||||
let config_contents =
|
||||
fs::read_to_string(&plano_config_path).expect("Failed to read plano_config.yaml");
|
||||
let contents = fs::read_to_string(&path).map_err(|e| format!("failed to read {path}: {e}"))?;
|
||||
|
||||
let config: Configuration =
|
||||
serde_yaml::from_str(&config_contents).expect("Failed to parse plano_config.yaml");
|
||||
serde_yaml::from_str(&contents).map_err(|e| format!("failed to parse {path}: {e}"))?;
|
||||
|
||||
// Initialize tracing using config.yaml tracing section
|
||||
let _tracer_provider = init_tracer(config.tracing.as_ref());
|
||||
info!(path = %plano_config_path, "loaded plano_config.yaml");
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
let plano_config = Arc::new(config);
|
||||
// ---------------------------------------------------------------------------
|
||||
// Application state initialization
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// combine agents and filters into a single list of agents
|
||||
let all_agents: Vec<Agent> = plano_config
|
||||
/// Build the shared [`AppState`] from a parsed [`Configuration`].
|
||||
async fn init_app_state(
|
||||
config: &Configuration,
|
||||
) -> Result<AppState, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let llm_provider_url =
|
||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
||||
// Combine agents and filters into a single list
|
||||
let all_agents: Vec<Agent> = config
|
||||
.agents
|
||||
.as_deref()
|
||||
.unwrap_or_default()
|
||||
.iter()
|
||||
.chain(plano_config.filters.as_deref().unwrap_or_default())
|
||||
.chain(config.filters.as_deref().unwrap_or_default())
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
// Build global agent map for resolving filter chain references
|
||||
let global_agent_map: HashMap<String, Agent> = all_agents
|
||||
.iter()
|
||||
.map(|a| (a.id.clone(), a.clone()))
|
||||
.collect();
|
||||
|
||||
// Create expanded provider list for /v1/models endpoint
|
||||
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
|
||||
.expect("Failed to create LlmProviders");
|
||||
let llm_providers = Arc::new(RwLock::new(llm_providers));
|
||||
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
||||
let llm_providers = LlmProviders::try_from(config.model_providers.clone())
|
||||
.map_err(|e| format!("failed to create LlmProviders: {e}"))?;
|
||||
|
||||
// Resolve model listener filter chain and agents at startup
|
||||
let model_listener_count = plano_config
|
||||
let model_listener_count = config
|
||||
.listeners
|
||||
.iter()
|
||||
.filter(|l| l.listener_type == ListenerType::Model)
|
||||
.count();
|
||||
assert!(
|
||||
model_listener_count <= 1,
|
||||
"only one model listener is allowed, found {}",
|
||||
model_listener_count
|
||||
);
|
||||
let model_listener = plano_config
|
||||
if model_listener_count > 1 {
|
||||
return Err(format!(
|
||||
"only one model listener is allowed, found {}",
|
||||
model_listener_count
|
||||
)
|
||||
.into());
|
||||
}
|
||||
let model_listener = config
|
||||
.listeners
|
||||
.iter()
|
||||
.find(|l| l.listener_type == ListenerType::Model);
|
||||
|
|
@ -114,7 +129,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
filter_ids.map(|ids| {
|
||||
let agents = ids
|
||||
.iter()
|
||||
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
|
||||
.filter_map(|id| {
|
||||
global_agent_map
|
||||
.get(id)
|
||||
.map(|a: &Agent| (id.clone(), a.clone()))
|
||||
})
|
||||
.collect();
|
||||
ResolvedFilterChain {
|
||||
filter_ids: ids,
|
||||
|
|
@ -126,14 +145,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())),
|
||||
output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())),
|
||||
});
|
||||
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
|
||||
let llm_provider_url =
|
||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
let overrides = plano_config.overrides.clone().unwrap_or_default();
|
||||
let overrides = config.overrides.clone().unwrap_or_default();
|
||||
|
||||
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
|
||||
let routing_model_name: String = overrides
|
||||
.llm_routing_model
|
||||
.as_deref()
|
||||
|
|
@ -141,21 +155,20 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
.unwrap_or(DEFAULT_ROUTING_MODEL_NAME)
|
||||
.to_string();
|
||||
|
||||
let routing_llm_provider = plano_config
|
||||
let routing_llm_provider = config
|
||||
.model_providers
|
||||
.iter()
|
||||
.find(|p| p.model.as_deref() == Some(routing_model_name.as_str()))
|
||||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||
plano_config.model_providers.clone(),
|
||||
let router_service = Arc::new(RouterService::new(
|
||||
config.model_providers.clone(),
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
routing_model_name,
|
||||
routing_llm_provider,
|
||||
));
|
||||
|
||||
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
|
||||
let orchestrator_model_name: String = overrides
|
||||
.agent_orchestration_model
|
||||
.as_deref()
|
||||
|
|
@ -163,213 +176,205 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
.unwrap_or(DEFAULT_ORCHESTRATOR_MODEL_NAME)
|
||||
.to_string();
|
||||
|
||||
let orchestrator_llm_provider: String = plano_config
|
||||
let orchestrator_llm_provider: String = config
|
||||
.model_providers
|
||||
.iter()
|
||||
.find(|p| p.model.as_deref() == Some(orchestrator_model_name.as_str()))
|
||||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string());
|
||||
|
||||
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
|
||||
let orchestrator_service = Arc::new(OrchestratorService::new(
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
orchestrator_model_name,
|
||||
orchestrator_llm_provider,
|
||||
));
|
||||
|
||||
let model_aliases = Arc::new(plano_config.model_aliases.clone());
|
||||
let span_attributes = Arc::new(
|
||||
plano_config
|
||||
.tracing
|
||||
.as_ref()
|
||||
.and_then(|tracing| tracing.span_attributes.clone()),
|
||||
);
|
||||
let state_storage = init_state_storage(config).await?;
|
||||
|
||||
// Initialize trace collector and start background flusher
|
||||
// Tracing is enabled if the tracing config is present in plano_config.yaml
|
||||
// Pass Some(true/false) to override, or None to use env var OTEL_TRACING_ENABLED
|
||||
// OpenTelemetry automatic instrumentation is configured in utils/tracing.rs
|
||||
let span_attributes = config
|
||||
.tracing
|
||||
.as_ref()
|
||||
.and_then(|tracing| tracing.span_attributes.clone());
|
||||
|
||||
// Initialize conversation state storage for v1/responses
|
||||
// Configurable via plano_config.yaml state_storage section
|
||||
// If not configured, state management is disabled
|
||||
// Environment variables are substituted by envsubst before config is read
|
||||
let state_storage: Option<Arc<dyn StateStorage>> =
|
||||
if let Some(storage_config) = &plano_config.state_storage {
|
||||
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
|
||||
common::configuration::StateStorageType::Memory => {
|
||||
info!(
|
||||
storage_type = "memory",
|
||||
"initialized conversation state storage"
|
||||
);
|
||||
Arc::new(MemoryConversationalStorage::new())
|
||||
}
|
||||
common::configuration::StateStorageType::Postgres => {
|
||||
let connection_string = storage_config
|
||||
.connection_string
|
||||
.as_ref()
|
||||
.expect("connection_string is required for postgres state_storage");
|
||||
Ok(AppState {
|
||||
router_service,
|
||||
orchestrator_service,
|
||||
model_aliases: config.model_aliases.clone(),
|
||||
llm_providers: Arc::new(RwLock::new(llm_providers)),
|
||||
agents_list: Some(all_agents),
|
||||
listeners: config.listeners.clone(),
|
||||
state_storage,
|
||||
llm_provider_url,
|
||||
span_attributes,
|
||||
http_client: reqwest::Client::new(),
|
||||
filter_pipeline,
|
||||
})
|
||||
}
|
||||
|
||||
debug!(connection_string = %connection_string, "postgres connection");
|
||||
info!(
|
||||
storage_type = "postgres",
|
||||
"initializing conversation state storage"
|
||||
);
|
||||
Arc::new(
|
||||
PostgreSQLConversationStorage::new(connection_string.clone())
|
||||
.await
|
||||
.expect("Failed to initialize Postgres state storage"),
|
||||
)
|
||||
}
|
||||
};
|
||||
Some(storage)
|
||||
} else {
|
||||
info!("no state_storage configured, conversation state management disabled");
|
||||
None
|
||||
};
|
||||
/// Initialize the conversation state storage backend (if configured).
|
||||
async fn init_state_storage(
|
||||
config: &Configuration,
|
||||
) -> Result<Option<Arc<dyn StateStorage>>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let Some(storage_config) = &config.state_storage else {
|
||||
info!("no state_storage configured, conversation state management disabled");
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
let peer_addr = stream.peer_addr()?;
|
||||
let io = TokioIo::new(stream);
|
||||
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
|
||||
common::configuration::StateStorageType::Memory => {
|
||||
info!(
|
||||
storage_type = "memory",
|
||||
"initialized conversation state storage"
|
||||
);
|
||||
Arc::new(MemoryConversationalStorage::new())
|
||||
}
|
||||
common::configuration::StateStorageType::Postgres => {
|
||||
let connection_string = storage_config
|
||||
.connection_string
|
||||
.as_ref()
|
||||
.ok_or("connection_string is required for postgres state_storage")?;
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||
let orchestrator_service: Arc<OrchestratorService> = Arc::clone(&orchestrator_service);
|
||||
let model_aliases: Arc<
|
||||
Option<std::collections::HashMap<String, common::configuration::ModelAlias>>,
|
||||
> = Arc::clone(&model_aliases);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
debug!(connection_string = %connection_string, "postgres connection");
|
||||
info!(
|
||||
storage_type = "postgres",
|
||||
"initializing conversation state storage"
|
||||
);
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let agents_list = combined_agents_filters_list.clone();
|
||||
let filter_pipeline = filter_pipeline.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
let service = service_fn(move |req| {
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let orchestrator_service = Arc::clone(&orchestrator_service);
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
let llm_providers = llm_providers.clone();
|
||||
let model_aliases = Arc::clone(&model_aliases);
|
||||
let agents_list = agents_list.clone();
|
||||
let filter_pipeline = filter_pipeline.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
Arc::new(
|
||||
PostgreSQLConversationStorage::new(connection_string.clone())
|
||||
.await
|
||||
.map_err(|e| format!("failed to initialize Postgres state storage: {e}"))?,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
async move {
|
||||
let path = req.uri().path().to_string();
|
||||
// Check if path starts with /agents
|
||||
if path.starts_with("/agents") {
|
||||
// Check if it matches one of the agent API paths
|
||||
let stripped_path = path.strip_prefix("/agents").unwrap();
|
||||
if matches!(
|
||||
stripped_path,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, stripped_path);
|
||||
return agent_chat(
|
||||
req,
|
||||
orchestrator_service,
|
||||
fully_qualified_url,
|
||||
agents_list,
|
||||
listeners,
|
||||
span_attributes,
|
||||
llm_providers,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
if let Some(stripped_path) = path.strip_prefix("/routing") {
|
||||
let stripped_path = stripped_path.to_string();
|
||||
if matches!(
|
||||
stripped_path.as_str(),
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
return routing_decision(
|
||||
req,
|
||||
router_service,
|
||||
stripped_path,
|
||||
span_attributes,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
match (req.method(), path.as_str()) {
|
||||
(
|
||||
&Method::POST,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH,
|
||||
) => {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, path);
|
||||
llm_chat(
|
||||
req,
|
||||
router_service,
|
||||
fully_qualified_url,
|
||||
model_aliases,
|
||||
llm_providers,
|
||||
span_attributes,
|
||||
state_storage,
|
||||
filter_pipeline,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::POST, "/function_calling") => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, "/v1/chat/completions");
|
||||
function_calling_chat_handler(req, fully_qualified_url)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
|
||||
Ok(list_models(llm_providers).await)
|
||||
}
|
||||
// hack for now to get openw-web-ui to work
|
||||
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => {
|
||||
let mut response = Response::new(empty());
|
||||
*response.status_mut() = StatusCode::NO_CONTENT;
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Allow", "GET, OPTIONS".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Access-Control-Allow-Origin", "*".parse().unwrap());
|
||||
response.headers_mut().insert(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Authorization, Content-Type".parse().unwrap(),
|
||||
);
|
||||
response.headers_mut().insert(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET, POST, OPTIONS".parse().unwrap(),
|
||||
);
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
Ok(Some(storage))
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
_ => {
|
||||
debug!(method = %req.method(), path = %path, "no route found");
|
||||
let mut not_found = Response::new(empty());
|
||||
*not_found.status_mut() = StatusCode::NOT_FOUND;
|
||||
Ok(not_found)
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
// ---------------------------------------------------------------------------
|
||||
// Request routing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
debug!(peer = ?peer_addr, "accepted connection");
|
||||
if let Err(err) = http1::Builder::new()
|
||||
// .serve_connection(io, service_fn(chat_completion))
|
||||
.serve_connection(io, service)
|
||||
/// Route an incoming HTTP request to the appropriate handler.
|
||||
async fn route(
|
||||
req: Request<Incoming>,
|
||||
state: Arc<AppState>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let parent_cx = global::get_text_map_propagator(|p| p.extract(&HeaderExtractor(req.headers())));
|
||||
let path = req.uri().path().to_string();
|
||||
|
||||
// --- Agent routes (/agents/...) ---
|
||||
if let Some(stripped) = path.strip_prefix("/agents") {
|
||||
if matches!(
|
||||
stripped,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
return agent_chat(req, Arc::clone(&state))
|
||||
.with_context(parent_cx)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Routing decision routes (/routing/...) ---
|
||||
if let Some(stripped) = path.strip_prefix("/routing") {
|
||||
let stripped = stripped.to_string();
|
||||
if matches!(
|
||||
stripped.as_str(),
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
return routing_decision(
|
||||
req,
|
||||
Arc::clone(&state.router_service),
|
||||
stripped,
|
||||
&state.span_attributes,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Standard routes ---
|
||||
match (req.method(), path.as_str()) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||
llm_chat(req, Arc::clone(&state))
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
{
|
||||
warn!(error = ?err, "error serving connection");
|
||||
}
|
||||
});
|
||||
}
|
||||
(&Method::POST, "/function_calling") => {
|
||||
let url = format!("{}/v1/chat/completions", state.llm_provider_url);
|
||||
function_calling_chat_handler(req, url)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
|
||||
Ok(list_models(Arc::clone(&state.llm_providers)).await)
|
||||
}
|
||||
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => cors_preflight(),
|
||||
_ => {
|
||||
debug!(method = %req.method(), path = %path, "no route found");
|
||||
let mut not_found = Response::new(empty());
|
||||
*not_found.status_mut() = StatusCode::NOT_FOUND;
|
||||
Ok(not_found)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Server loop
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Accept connections and spawn a task per connection.
|
||||
///
|
||||
/// Listens for `SIGINT` / `ctrl-c` and shuts down gracefully, allowing
|
||||
/// in-flight connections to finish.
|
||||
async fn run_server(state: Arc<AppState>) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
|
||||
let listener = TcpListener::bind(&bind_address).await?;
|
||||
info!(address = %bind_address, "server listening");
|
||||
|
||||
let shutdown = tokio::signal::ctrl_c();
|
||||
tokio::pin!(shutdown);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = listener.accept() => {
|
||||
let (stream, _) = result?;
|
||||
let peer_addr = stream.peer_addr()?;
|
||||
let io = TokioIo::new(stream);
|
||||
let state = Arc::clone(&state);
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
debug!(peer = ?peer_addr, "accepted connection");
|
||||
|
||||
let service = service_fn(move |req| {
|
||||
let state = Arc::clone(&state);
|
||||
async move { route(req, state).await }
|
||||
});
|
||||
|
||||
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
|
||||
warn!(error = ?err, "error serving connection");
|
||||
}
|
||||
});
|
||||
}
|
||||
_ = &mut shutdown => {
|
||||
info!("received shutdown signal, stopping server");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Entry point
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let config = load_config()?;
|
||||
let _tracer_provider = init_tracer(config.tracing.as_ref());
|
||||
info!("loaded plano_config.yaml");
|
||||
let state = Arc::new(init_app_state(&config).await?);
|
||||
run_server(state).await
|
||||
}
|
||||
|
|
|
|||
48
crates/brightstaff/src/router/http.rs
Normal file
48
crates/brightstaff/src/router/http.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
use hermesllm::apis::openai::ChatCompletionsResponse;
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum HttpError {
|
||||
#[error("Failed to send request: {0}")]
|
||||
Request(#[from] reqwest::Error),
|
||||
|
||||
#[error("Failed to parse JSON response: {0}")]
|
||||
Json(serde_json::Error, String),
|
||||
}
|
||||
|
||||
/// Sends a POST request to the given URL and extracts the text content
|
||||
/// from the first choice of the `ChatCompletionsResponse`.
|
||||
///
|
||||
/// Returns `Some((content, elapsed))` on success, or `None` if the response
|
||||
/// had no choices or the first choice had no content.
|
||||
pub async fn post_and_extract_content(
|
||||
client: &reqwest::Client,
|
||||
url: &str,
|
||||
headers: header::HeaderMap,
|
||||
body: String,
|
||||
) -> Result<Option<(String, std::time::Duration)>, HttpError> {
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
let res = client.post(url).headers(headers).body(body).send().await?;
|
||||
|
||||
let body = res.text().await?;
|
||||
let elapsed = start_time.elapsed();
|
||||
|
||||
let response: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|err| {
|
||||
warn!(error = %err, body = %body, "failed to parse json response");
|
||||
HttpError::Json(err, format!("Failed to parse JSON: {}", body))
|
||||
})?;
|
||||
|
||||
if response.choices.is_empty() {
|
||||
warn!(body = %body, "no choices in response");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(response.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.map(|c| (c.clone(), elapsed)))
|
||||
}
|
||||
148
crates/brightstaff/src/router/llm.rs
Normal file
148
crates/brightstaff/src/router/llm.rs
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use common::{
|
||||
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER},
|
||||
};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::http::{self, post_and_extract_content};
|
||||
use super::router_model::RouterModel;
|
||||
|
||||
use crate::router::router_model_v1;
|
||||
|
||||
pub struct RouterService {
|
||||
router_url: String,
|
||||
client: reqwest::Client,
|
||||
router_model: Arc<dyn RouterModel>,
|
||||
routing_provider_name: String,
|
||||
llm_usage_defined: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RoutingError {
|
||||
#[error(transparent)]
|
||||
Http(#[from] http::HttpError),
|
||||
|
||||
#[error("Router model error: {0}")]
|
||||
RouterModelError(#[from] super::router_model::RoutingModelError),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingError>;
|
||||
|
||||
impl RouterService {
|
||||
pub fn new(
|
||||
providers: Vec<LlmProvider>,
|
||||
router_url: String,
|
||||
routing_model_name: String,
|
||||
routing_provider_name: String,
|
||||
) -> Self {
|
||||
let providers_with_usage = providers
|
||||
.iter()
|
||||
.filter(|provider| provider.routing_preferences.is_some())
|
||||
.cloned()
|
||||
.collect::<Vec<LlmProvider>>();
|
||||
|
||||
let llm_routes: HashMap<String, Vec<RoutingPreference>> = providers_with_usage
|
||||
.iter()
|
||||
.filter_map(|provider| {
|
||||
provider
|
||||
.routing_preferences
|
||||
.as_ref()
|
||||
.map(|prefs| (provider.name.clone(), prefs.clone()))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
|
||||
llm_routes,
|
||||
routing_model_name,
|
||||
router_model_v1::MAX_TOKEN_LEN,
|
||||
));
|
||||
|
||||
RouterService {
|
||||
router_url,
|
||||
client: reqwest::Client::new(),
|
||||
router_model,
|
||||
routing_provider_name,
|
||||
llm_usage_defined: !providers_with_usage.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn determine_route(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
traceparent: &str,
|
||||
usage_preferences: Option<Vec<ModelUsagePreference>>,
|
||||
request_id: &str,
|
||||
) -> Result<Option<(String, String)>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if usage_preferences
|
||||
.as_ref()
|
||||
.is_none_or(|prefs| prefs.len() < 2)
|
||||
&& !self.llm_usage_defined
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let router_request = self
|
||||
.router_model
|
||||
.generate_request(messages, &usage_preferences);
|
||||
|
||||
debug!(
|
||||
model = %self.router_model.get_model_name(),
|
||||
endpoint = %self.router_url,
|
||||
"sending request to arch-router"
|
||||
);
|
||||
|
||||
let body = serde_json::to_string(&router_request)
|
||||
.map_err(super::router_model::RoutingModelError::from)?;
|
||||
debug!(body = %body, "arch router request");
|
||||
|
||||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
if let Ok(val) = header::HeaderValue::from_str(&self.routing_provider_name) {
|
||||
headers.insert(
|
||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||
val,
|
||||
);
|
||||
}
|
||||
if let Ok(val) = header::HeaderValue::from_str(traceparent) {
|
||||
headers.insert(header::HeaderName::from_static(TRACE_PARENT_HEADER), val);
|
||||
}
|
||||
if let Ok(val) = header::HeaderValue::from_str(request_id) {
|
||||
headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val);
|
||||
}
|
||||
headers.insert(
|
||||
header::HeaderName::from_static("model"),
|
||||
header::HeaderValue::from_static("arch-router"),
|
||||
);
|
||||
|
||||
let Some((content, elapsed)) =
|
||||
post_and_extract_content(&self.client, &self.router_url, headers, body).await?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let parsed = self
|
||||
.router_model
|
||||
.parse_response(&content, &usage_preferences)?;
|
||||
|
||||
info!(
|
||||
content = %content.replace("\n", "\\n"),
|
||||
selected_model = ?parsed,
|
||||
response_time_ms = elapsed.as_millis(),
|
||||
"arch-router determined route"
|
||||
);
|
||||
|
||||
Ok(parsed)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,187 +0,0 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use common::{
|
||||
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER},
|
||||
};
|
||||
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::router_model_v1::{self};
|
||||
|
||||
use super::router_model::RouterModel;
|
||||
|
||||
pub struct RouterService {
|
||||
router_url: String,
|
||||
client: reqwest::Client,
|
||||
router_model: Arc<dyn RouterModel>,
|
||||
#[allow(dead_code)]
|
||||
routing_provider_name: String,
|
||||
llm_usage_defined: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RoutingError {
|
||||
#[error("Failed to send request: {0}")]
|
||||
RequestError(#[from] reqwest::Error),
|
||||
|
||||
#[error("Failed to parse JSON: {0}, JSON: {1}")]
|
||||
JsonError(serde_json::Error, String),
|
||||
|
||||
#[error("Router model error: {0}")]
|
||||
RouterModelError(#[from] super::router_model::RoutingModelError),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingError>;
|
||||
|
||||
impl RouterService {
|
||||
pub fn new(
|
||||
providers: Vec<LlmProvider>,
|
||||
router_url: String,
|
||||
routing_model_name: String,
|
||||
routing_provider_name: String,
|
||||
) -> Self {
|
||||
let providers_with_usage = providers
|
||||
.iter()
|
||||
.filter(|provider| provider.routing_preferences.is_some())
|
||||
.cloned()
|
||||
.collect::<Vec<LlmProvider>>();
|
||||
|
||||
let llm_routes: HashMap<String, Vec<RoutingPreference>> = providers_with_usage
|
||||
.iter()
|
||||
.filter_map(|provider| {
|
||||
provider
|
||||
.routing_preferences
|
||||
.as_ref()
|
||||
.map(|prefs| (provider.name.clone(), prefs.clone()))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
|
||||
llm_routes,
|
||||
routing_model_name,
|
||||
router_model_v1::MAX_TOKEN_LEN,
|
||||
));
|
||||
|
||||
RouterService {
|
||||
router_url,
|
||||
client: reqwest::Client::new(),
|
||||
router_model,
|
||||
routing_provider_name,
|
||||
llm_usage_defined: !providers_with_usage.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn determine_route(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
traceparent: &str,
|
||||
usage_preferences: Option<Vec<ModelUsagePreference>>,
|
||||
request_id: &str,
|
||||
) -> Result<Option<(String, String)>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if (usage_preferences.is_none() || usage_preferences.as_ref().unwrap().len() < 2)
|
||||
&& !self.llm_usage_defined
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let router_request = self
|
||||
.router_model
|
||||
.generate_request(messages, &usage_preferences);
|
||||
|
||||
debug!(
|
||||
model = %self.router_model.get_model_name(),
|
||||
endpoint = %self.router_url,
|
||||
"sending request to arch-router"
|
||||
);
|
||||
|
||||
debug!(
|
||||
body = %serde_json::to_string(&router_request).unwrap(),
|
||||
"arch router request"
|
||||
);
|
||||
|
||||
let mut llm_route_request_headers = header::HeaderMap::new();
|
||||
llm_route_request_headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
llm_route_request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||
header::HeaderValue::from_str(&self.routing_provider_name).unwrap(),
|
||||
);
|
||||
|
||||
llm_route_request_headers.insert(
|
||||
header::HeaderName::from_static(TRACE_PARENT_HEADER),
|
||||
header::HeaderValue::from_str(traceparent).unwrap(),
|
||||
);
|
||||
|
||||
llm_route_request_headers.insert(
|
||||
header::HeaderName::from_static(REQUEST_ID_HEADER),
|
||||
header::HeaderValue::from_str(request_id).unwrap(),
|
||||
);
|
||||
|
||||
llm_route_request_headers.insert(
|
||||
header::HeaderName::from_static("model"),
|
||||
header::HeaderValue::from_static("arch-router"),
|
||||
);
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let res = self
|
||||
.client
|
||||
.post(&self.router_url)
|
||||
.headers(llm_route_request_headers)
|
||||
.body(serde_json::to_string(&router_request).unwrap())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let body = res.text().await?;
|
||||
let router_response_time = start_time.elapsed();
|
||||
|
||||
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
error = %err,
|
||||
body = %serde_json::to_string(&body).unwrap(),
|
||||
"failed to parse json response"
|
||||
);
|
||||
return Err(RoutingError::JsonError(
|
||||
err,
|
||||
format!("Failed to parse JSON: {}", body),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if chat_completion_response.choices.is_empty() {
|
||||
warn!(body = %body, "no choices in router response");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if let Some(content) = &chat_completion_response.choices[0].message.content {
|
||||
let parsed_response = self
|
||||
.router_model
|
||||
.parse_response(content, &usage_preferences)?;
|
||||
info!(
|
||||
content = %content.replace("\n", "\\n"),
|
||||
selected_model = ?parsed_response,
|
||||
response_time_ms = router_response_time.as_millis(),
|
||||
"arch-router determined route"
|
||||
);
|
||||
|
||||
if let Some(ref parsed_response) = parsed_response {
|
||||
return Ok(Some(parsed_response.clone()));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
pub mod llm_router;
|
||||
pub(crate) mod http;
|
||||
pub mod llm;
|
||||
pub mod orchestrator;
|
||||
pub mod orchestrator_model;
|
||||
pub mod orchestrator_model_v1;
|
||||
pub mod plano_orchestrator;
|
||||
pub mod router_model;
|
||||
pub mod router_model_v1;
|
||||
|
|
|
|||
139
crates/brightstaff/src/router/orchestrator.rs
Normal file
139
crates/brightstaff/src/router/orchestrator.rs
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use common::{
|
||||
configuration::{AgentUsagePreference, OrchestrationPreference},
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER},
|
||||
};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hyper::header;
|
||||
use opentelemetry::global;
|
||||
use opentelemetry_http::HeaderInjector;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::http::{self, post_and_extract_content};
|
||||
use super::orchestrator_model::OrchestratorModel;
|
||||
|
||||
use crate::router::orchestrator_model_v1;
|
||||
|
||||
pub struct OrchestratorService {
|
||||
orchestrator_url: String,
|
||||
client: reqwest::Client,
|
||||
orchestrator_model: Arc<dyn OrchestratorModel>,
|
||||
orchestrator_provider_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OrchestrationError {
|
||||
#[error(transparent)]
|
||||
Http(#[from] http::HttpError),
|
||||
|
||||
#[error("Orchestrator model error: {0}")]
|
||||
OrchestratorModelError(#[from] super::orchestrator_model::OrchestratorModelError),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, OrchestrationError>;
|
||||
|
||||
impl OrchestratorService {
|
||||
pub fn new(
|
||||
orchestrator_url: String,
|
||||
orchestration_model_name: String,
|
||||
orchestrator_provider_name: String,
|
||||
) -> Self {
|
||||
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
|
||||
|
||||
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model_name.clone(),
|
||||
orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
));
|
||||
|
||||
OrchestratorService {
|
||||
orchestrator_url,
|
||||
client: reqwest::Client::new(),
|
||||
orchestrator_model,
|
||||
orchestrator_provider_name,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn determine_orchestration(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
usage_preferences: Option<Vec<AgentUsagePreference>>,
|
||||
request_id: Option<String>,
|
||||
) -> Result<Option<Vec<(String, String)>>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if usage_preferences
|
||||
.as_ref()
|
||||
.is_none_or(|prefs| prefs.is_empty())
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let orchestrator_request = self
|
||||
.orchestrator_model
|
||||
.generate_request(messages, &usage_preferences);
|
||||
|
||||
debug!(
|
||||
model = %self.orchestrator_model.get_model_name(),
|
||||
endpoint = %self.orchestrator_url,
|
||||
"sending request to arch-orchestrator"
|
||||
);
|
||||
|
||||
let body = serde_json::to_string(&orchestrator_request)
|
||||
.map_err(super::orchestrator_model::OrchestratorModelError::from)?;
|
||||
debug!(body = %body, "arch orchestrator request");
|
||||
|
||||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
headers.insert(
|
||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||
header::HeaderValue::from_str(&self.orchestrator_provider_name)
|
||||
.unwrap_or_else(|_| header::HeaderValue::from_static("plano-orchestrator")),
|
||||
);
|
||||
|
||||
// Inject OpenTelemetry trace context from current span
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut headers));
|
||||
});
|
||||
|
||||
if let Some(ref request_id) = request_id {
|
||||
if let Ok(val) = header::HeaderValue::from_str(request_id) {
|
||||
headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val);
|
||||
}
|
||||
}
|
||||
|
||||
headers.insert(
|
||||
header::HeaderName::from_static("model"),
|
||||
header::HeaderValue::from_str(&self.orchestrator_provider_name)
|
||||
.unwrap_or_else(|_| header::HeaderValue::from_static("plano-orchestrator")),
|
||||
);
|
||||
|
||||
let Some((content, elapsed)) =
|
||||
post_and_extract_content(&self.client, &self.orchestrator_url, headers, body).await?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let parsed = self
|
||||
.orchestrator_model
|
||||
.parse_response(&content, &usage_preferences)?;
|
||||
|
||||
info!(
|
||||
content = %content.replace("\n", "\\n"),
|
||||
selected_routes = ?parsed,
|
||||
response_time_ms = elapsed.as_millis(),
|
||||
"arch-orchestrator determined routes"
|
||||
);
|
||||
|
||||
Ok(parsed)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,174 +0,0 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use common::{
|
||||
configuration::{AgentUsagePreference, OrchestrationPreference},
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER},
|
||||
};
|
||||
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
|
||||
use hyper::header;
|
||||
use opentelemetry::global;
|
||||
use opentelemetry_http::HeaderInjector;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::orchestrator_model_v1::{self};
|
||||
|
||||
use super::orchestrator_model::OrchestratorModel;
|
||||
|
||||
pub struct OrchestratorService {
|
||||
orchestrator_url: String,
|
||||
client: reqwest::Client,
|
||||
orchestrator_model: Arc<dyn OrchestratorModel>,
|
||||
orchestrator_provider_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OrchestrationError {
|
||||
#[error("Failed to send request: {0}")]
|
||||
RequestError(#[from] reqwest::Error),
|
||||
|
||||
#[error("Failed to parse JSON: {0}, JSON: {1}")]
|
||||
JsonError(serde_json::Error, String),
|
||||
|
||||
#[error("Orchestrator model error: {0}")]
|
||||
OrchestratorModelError(#[from] super::orchestrator_model::OrchestratorModelError),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, OrchestrationError>;
|
||||
|
||||
impl OrchestratorService {
|
||||
pub fn new(
|
||||
orchestrator_url: String,
|
||||
orchestration_model_name: String,
|
||||
orchestrator_provider_name: String,
|
||||
) -> Self {
|
||||
// Empty agent orchestrations - will be provided via usage_preferences in requests
|
||||
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
|
||||
|
||||
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model_name.clone(),
|
||||
orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
));
|
||||
|
||||
OrchestratorService {
|
||||
orchestrator_url,
|
||||
client: reqwest::Client::new(),
|
||||
orchestrator_model,
|
||||
orchestrator_provider_name,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn determine_orchestration(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
usage_preferences: Option<Vec<AgentUsagePreference>>,
|
||||
request_id: Option<String>,
|
||||
) -> Result<Option<Vec<(String, String)>>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Require usage_preferences to be provided
|
||||
if usage_preferences.is_none() || usage_preferences.as_ref().unwrap().is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let orchestrator_request = self
|
||||
.orchestrator_model
|
||||
.generate_request(messages, &usage_preferences);
|
||||
|
||||
debug!(
|
||||
model = %self.orchestrator_model.get_model_name(),
|
||||
endpoint = %self.orchestrator_url,
|
||||
"sending request to plano-orchestrator"
|
||||
);
|
||||
|
||||
debug!(
|
||||
body = %serde_json::to_string(&orchestrator_request).unwrap(),
|
||||
"plano orchestrator request"
|
||||
);
|
||||
|
||||
let mut orchestration_request_headers = header::HeaderMap::new();
|
||||
orchestration_request_headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
orchestration_request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||
header::HeaderValue::from_str(&self.orchestrator_provider_name).unwrap(),
|
||||
);
|
||||
|
||||
// Inject OpenTelemetry trace context from current span
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut orchestration_request_headers));
|
||||
});
|
||||
|
||||
if let Some(request_id) = request_id {
|
||||
orchestration_request_headers.insert(
|
||||
header::HeaderName::from_static(REQUEST_ID_HEADER),
|
||||
header::HeaderValue::from_str(&request_id).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
orchestration_request_headers.insert(
|
||||
header::HeaderName::from_static("model"),
|
||||
header::HeaderValue::from_str(&self.orchestrator_provider_name).unwrap(),
|
||||
);
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let res = self
|
||||
.client
|
||||
.post(&self.orchestrator_url)
|
||||
.headers(orchestration_request_headers)
|
||||
.body(serde_json::to_string(&orchestrator_request).unwrap())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let body = res.text().await?;
|
||||
let orchestrator_response_time = start_time.elapsed();
|
||||
|
||||
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
error = %err,
|
||||
body = %serde_json::to_string(&body).unwrap(),
|
||||
"failed to parse json response"
|
||||
);
|
||||
return Err(OrchestrationError::JsonError(
|
||||
err,
|
||||
format!("Failed to parse JSON: {}", body),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if chat_completion_response.choices.is_empty() {
|
||||
warn!(body = %body, "no choices in orchestrator response");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if let Some(content) = &chat_completion_response.choices[0].message.content {
|
||||
let parsed_response = self
|
||||
.orchestrator_model
|
||||
.parse_response(content, &usage_preferences)?;
|
||||
info!(
|
||||
content = %content.replace("\n", "\\n"),
|
||||
selected_routes = ?parsed_response,
|
||||
response_time_ms = orchestrator_response_time.as_millis(),
|
||||
"arch-orchestrator determined routes"
|
||||
);
|
||||
|
||||
if let Some(ref parsed_response) = parsed_response {
|
||||
return Ok(Some(parsed_response.clone()));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -88,35 +88,18 @@ pub trait StateStorage: Send + Sync {
|
|||
combined_input.extend(current_input);
|
||||
|
||||
debug!(
|
||||
"PLANO | BRIGHTSTAFF | STATE_STORAGE | RESP_ID:{} | Merged state: prev_items={}, current_items={}, total_items={}, combined_json={}",
|
||||
prev_state.response_id,
|
||||
prev_count,
|
||||
current_count,
|
||||
combined_input.len(),
|
||||
serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
|
||||
response_id = %prev_state.response_id,
|
||||
prev_items = prev_count,
|
||||
current_items = current_count,
|
||||
total_items = combined_input.len(),
|
||||
combined_json = %serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string()),
|
||||
"merged conversation state"
|
||||
);
|
||||
|
||||
combined_input
|
||||
}
|
||||
}
|
||||
|
||||
/// Storage backend type enum
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum StorageBackend {
|
||||
Memory,
|
||||
Supabase,
|
||||
}
|
||||
|
||||
impl StorageBackend {
|
||||
pub fn parse_backend(s: &str) -> Option<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"memory" => Some(StorageBackend::Memory),
|
||||
"supabase" => Some(StorageBackend::Supabase),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === Utility functions for state management ===
|
||||
|
||||
/// Extract input items from InputParam, converting text to structured format
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ use std::io::Read;
|
|||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::handlers::streaming::StreamProcessor;
|
||||
use crate::state::{OpenAIConversationState, StateStorage};
|
||||
use crate::streaming::StreamProcessor;
|
||||
|
||||
/// Processor that wraps another processor and handles v1/responses state management
|
||||
/// Captures response_id and output from streaming responses, stores state after completion
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ use tokio_stream::StreamExt;
|
|||
use tracing::{debug, info, warn, Instrument};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
use crate::handlers::agents::pipeline::{PipelineError, PipelineProcessor};
|
||||
|
||||
const STREAM_BUFFER_SIZE: usize = 16;
|
||||
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
|
||||
|
|
@ -52,6 +52,7 @@ pub fn collect_custom_trace_attributes(
|
|||
attributes
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn append_span_attributes(
|
||||
mut span_builder: SpanBuilder,
|
||||
attributes: &HashMap<String, String>,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ use tracing_subscriber::registry::LookupSpan;
|
|||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
use crate::tracing::ServiceNameOverrideExporter;
|
||||
use super::ServiceNameOverrideExporter;
|
||||
use common::configuration::Tracing;
|
||||
|
||||
struct BracketedTime;
|
||||
|
|
@ -96,17 +96,16 @@ pub fn init_tracer(tracing_config: Option<&Tracing>) -> &'static SdkTracerProvid
|
|||
tracing_enabled, otel_endpoint, random_sampling
|
||||
);
|
||||
|
||||
// Create OTLP exporter to send spans to collector
|
||||
if tracing_enabled {
|
||||
// Set service name via environment if not already set
|
||||
// Create OTLP exporter to send spans to collector.
|
||||
// Use `if let` to destructure the endpoint, avoiding an unwrap.
|
||||
if let Some(endpoint) = otel_endpoint.as_deref().filter(|_| tracing_enabled) {
|
||||
if std::env::var("OTEL_SERVICE_NAME").is_err() {
|
||||
std::env::set_var("OTEL_SERVICE_NAME", "plano");
|
||||
}
|
||||
|
||||
// Create ServiceNameOverrideExporter to support per-span service names
|
||||
// This allows spans to have different service names (e.g., plano(orchestrator),
|
||||
// plano(filter), plano(llm)) by setting the "service.name.override" attribute
|
||||
let exporter = ServiceNameOverrideExporter::new(otel_endpoint.as_ref().unwrap());
|
||||
let exporter = ServiceNameOverrideExporter::new(endpoint);
|
||||
|
||||
let provider = SdkTracerProvider::builder()
|
||||
.with_batch_exporter(exporter)
|
||||
|
|
@ -1,11 +1,13 @@
|
|||
mod constants;
|
||||
mod custom_attributes;
|
||||
mod init;
|
||||
mod service_name_exporter;
|
||||
|
||||
pub use constants::{
|
||||
error, http, llm, operation_component, routing, signals, OperationNameBuilder,
|
||||
};
|
||||
pub use custom_attributes::{append_span_attributes, collect_custom_trace_attributes};
|
||||
pub use custom_attributes::collect_custom_trace_attributes;
|
||||
pub use init::init_tracer;
|
||||
pub use service_name_exporter::{ServiceNameOverrideExporter, SERVICE_NAME_OVERRIDE_KEY};
|
||||
|
||||
use opentelemetry::trace::get_active_span;
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
pub mod tracing;
|
||||
Loading…
Add table
Add a link
Reference in a new issue