mirror of
https://github.com/katanemo/plano.git
synced 2026-05-07 23:02:43 +02:00
Use mcp tools for filter chain (#621)
* agents framework demo * more changes * add more changes * pending changes * fix tests * fix more * rebase with main and better handle error from mcp * add trace for filters * add test for client error, server error and for mcp error * update schema validate code and rename kind => type in agent_filter * fix agent description and pre-commit * fix tests * add provider specific request parsing in agents chat * fix precommit and tests * cleanup demo * update readme * fix pre-commit * refactor tracing * fix fmt * fix: handle MessageContent enum in responses API conversion - Update request.rs to handle new MessageContent enum structure from main - MessageContent can now be Text(String) or Items(Vec<InputContent>) - Handle new InputItem variants (ItemReference, FunctionCallOutput) - Fixes compilation error after merging latest main (#632) * address pr feedback * fix span * fix build * update openai version
This commit is contained in:
parent
cb82a83c7b
commit
2f9121407b
40 changed files with 4886 additions and 190 deletions
|
|
@ -1,16 +1,24 @@
|
|||
use std::sync::Arc;
|
||||
use std::time::{Instant, SystemTime};
|
||||
|
||||
use bytes::Bytes;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use common::consts::TRACE_PARENT_HEADER;
|
||||
use common::traces::{SpanBuilder, SpanKind, parse_traceparent, generate_random_span_id};
|
||||
use hermesllm::apis::OpenAIMessage;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::providers::request::ProviderRequest;
|
||||
use hermesllm::ProviderRequestType;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{Request, Response};
|
||||
use serde::ser::Error as SerError;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
use super::response_handler::ResponseHandler;
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::tracing::{OperationNameBuilder, operation_component, http};
|
||||
|
||||
/// Main errors for agent chat completions
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -33,8 +41,17 @@ pub async fn agent_chat(
|
|||
_: String,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
trace_collector: Arc<common::traces::TraceCollector>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
match handle_agent_chat(request, router_service, agents_list, listeners).await {
|
||||
match handle_agent_chat(
|
||||
request,
|
||||
router_service,
|
||||
agents_list,
|
||||
listeners,
|
||||
trace_collector,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
// Check if this is a client error from the pipeline that should be cascaded
|
||||
|
|
@ -109,10 +126,11 @@ async fn handle_agent_chat(
|
|||
router_service: Arc<RouterService>,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
trace_collector: Arc<common::traces::TraceCollector>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||
// Initialize services
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
let pipeline_processor = PipelineProcessor::default();
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let response_handler = ResponseHandler::new();
|
||||
|
||||
// Extract listener name from headers
|
||||
|
|
@ -132,6 +150,13 @@ async fn handle_agent_chat(
|
|||
info!("Handling request for listener: {}", listener.name);
|
||||
|
||||
// Parse request body
|
||||
let request_path = request
|
||||
.uri()
|
||||
.path()
|
||||
.to_string()
|
||||
.strip_prefix("/agents")
|
||||
.unwrap()
|
||||
.to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
|
|
@ -140,61 +165,141 @@ async fn handle_agent_chat(
|
|||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
);
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest =
|
||||
serde_json::from_slice(&chat_request_bytes).map_err(|err| {
|
||||
warn!(
|
||||
"Failed to parse request body as ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
AgentFilterChainError::RequestParsing(err)
|
||||
// Determine the API type from the endpoint
|
||||
let api_type =
|
||||
SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| {
|
||||
let err_msg = format!("Unsupported endpoint: {}", request_path);
|
||||
warn!("{}", err_msg);
|
||||
AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
|
||||
})?;
|
||||
|
||||
let client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse request as ProviderRequestType: {}", err);
|
||||
let err_msg = format!("Failed to parse request: {}", err);
|
||||
return Err(AgentFilterChainError::RequestParsing(
|
||||
serde_json::Error::custom(err_msg),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let message: Vec<OpenAIMessage> = client_request.get_messages();
|
||||
|
||||
// let chat_completions_request: ChatCompletionsRequest =
|
||||
// serde_json::from_slice(&chat_request_bytes).map_err(|err| {
|
||||
// warn!(
|
||||
// "Failed to parse request body as ChatCompletionsRequest: {}",
|
||||
// err
|
||||
// );
|
||||
// AgentFilterChainError::RequestParsing(err)
|
||||
// })?;
|
||||
|
||||
// Extract trace parent for routing
|
||||
let trace_parent = request_headers
|
||||
.iter()
|
||||
.find(|(key, _)| key.as_str() == "traceparent")
|
||||
.find(|(key, _)| key.as_str() == TRACE_PARENT_HEADER)
|
||||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
// Select appropriate agent using arch router llm model
|
||||
let selected_agent = agent_selector
|
||||
.select_agent(&chat_completions_request.messages, &listener, trace_parent)
|
||||
.await?;
|
||||
|
||||
debug!("Processing agent pipeline: {}", selected_agent.id);
|
||||
|
||||
// Create agent map for pipeline processing
|
||||
// Create agent map for pipeline processing and agent selection
|
||||
let agent_map = {
|
||||
let agents = agents_list.read().await;
|
||||
let agents = agents.as_ref().unwrap();
|
||||
agent_selector.create_agent_map(agents)
|
||||
};
|
||||
|
||||
// Parse trace parent to get trace_id and parent_span_id
|
||||
let (trace_id, parent_span_id) = if let Some(ref tp) = trace_parent {
|
||||
parse_traceparent(tp)
|
||||
} else {
|
||||
(String::new(), None)
|
||||
};
|
||||
|
||||
// Select appropriate agent using arch router llm model
|
||||
let selected_agent = agent_selector
|
||||
.select_agent(&message, &listener, trace_parent.clone())
|
||||
.await?;
|
||||
|
||||
debug!("Processing agent pipeline: {}", selected_agent.id);
|
||||
|
||||
// Record the start time for agent span
|
||||
let agent_start_time = SystemTime::now();
|
||||
let agent_start_instant = Instant::now();
|
||||
// let (span_id, trace_id) = trace_collector.start_span(
|
||||
// trace_parent.clone(),
|
||||
// operation_component::AGENT,
|
||||
// &format!("/agents{}", request_path),
|
||||
// &selected_agent.id,
|
||||
// );
|
||||
|
||||
let span_id = generate_random_span_id();
|
||||
|
||||
// Process the filter chain
|
||||
let processed_messages = pipeline_processor
|
||||
let chat_history = pipeline_processor
|
||||
.process_filter_chain(
|
||||
&chat_completions_request,
|
||||
&message,
|
||||
&selected_agent,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
Some(&trace_collector),
|
||||
trace_id.clone(),
|
||||
span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Get terminal agent and send final response
|
||||
let terminal_agent_name = selected_agent.id;
|
||||
let terminal_agent_name = selected_agent.id.clone();
|
||||
let terminal_agent = agent_map.get(&terminal_agent_name).unwrap();
|
||||
|
||||
debug!("Processing terminal agent: {}", terminal_agent_name);
|
||||
debug!("Terminal agent details: {:?}", terminal_agent);
|
||||
|
||||
let llm_response = pipeline_processor
|
||||
.invoke_upstream_agent(
|
||||
&processed_messages,
|
||||
&chat_completions_request,
|
||||
.invoke_agent(
|
||||
&chat_history,
|
||||
client_request,
|
||||
terminal_agent,
|
||||
&request_headers,
|
||||
trace_id.clone(),
|
||||
span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Record agent span after processing is complete
|
||||
let agent_end_time = SystemTime::now();
|
||||
let agent_elapsed = agent_start_instant.elapsed();
|
||||
|
||||
// Build full path with /agents prefix
|
||||
let full_path = format!("/agents{}", request_path);
|
||||
|
||||
// Build operation name: POST {full_path} {agent_name}
|
||||
let operation_name = OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path(&full_path)
|
||||
.with_target(&terminal_agent_name)
|
||||
.build();
|
||||
|
||||
let mut span_builder = SpanBuilder::new(&operation_name)
|
||||
.with_span_id(span_id)
|
||||
.with_kind(SpanKind::Internal)
|
||||
.with_start_time(agent_start_time)
|
||||
.with_end_time(agent_end_time)
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, full_path)
|
||||
.with_attribute("agent.name", terminal_agent_name.clone())
|
||||
.with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0));
|
||||
|
||||
if !trace_id.is_empty() {
|
||||
span_builder = span_builder.with_trace_id(trace_id);
|
||||
}
|
||||
if let Some(parent_id) = parent_span_id {
|
||||
span_builder = span_builder.with_parent_span_id(parent_id);
|
||||
}
|
||||
|
||||
let span = span_builder.build();
|
||||
// Use plano(agent) as service name for the agent processing span
|
||||
trace_collector.record_span(operation_component::AGENT, span);
|
||||
|
||||
// Create streaming response
|
||||
response_handler
|
||||
.create_streaming_response(llm_response)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ pub enum AgentSelectionError {
|
|||
RoutingError(String),
|
||||
#[error("Default agent not found for listener: {0}")]
|
||||
DefaultAgentNotFound(String),
|
||||
#[error("MCP client error: {0}")]
|
||||
McpError(String),
|
||||
}
|
||||
|
||||
/// Service for selecting agents based on routing preferences and listener configuration
|
||||
|
|
@ -29,7 +31,9 @@ pub struct AgentSelector {
|
|||
|
||||
impl AgentSelector {
|
||||
pub fn new(router_service: Arc<RouterService>) -> Self {
|
||||
Self { router_service }
|
||||
Self {
|
||||
router_service,
|
||||
}
|
||||
}
|
||||
|
||||
/// Find listener by name from the request headers
|
||||
|
|
@ -77,7 +81,9 @@ impl AgentSelector {
|
|||
return Ok(agents[0].clone());
|
||||
}
|
||||
|
||||
let usage_preferences = self.convert_agent_description_to_routing_preferences(agents);
|
||||
let usage_preferences = self
|
||||
.convert_agent_description_to_routing_preferences(agents)
|
||||
.await;
|
||||
debug!(
|
||||
"Agents usage preferences for agent routing str: {}",
|
||||
serde_json::to_string(&usage_preferences).unwrap_or_default()
|
||||
|
|
@ -131,20 +137,23 @@ impl AgentSelector {
|
|||
}
|
||||
|
||||
/// Convert agent descriptions to routing preferences
|
||||
fn convert_agent_description_to_routing_preferences(
|
||||
async fn convert_agent_description_to_routing_preferences(
|
||||
&self,
|
||||
agents: &[AgentFilterChain],
|
||||
) -> Vec<ModelUsagePreference> {
|
||||
agents
|
||||
.iter()
|
||||
.map(|agent| ModelUsagePreference {
|
||||
model: agent.id.clone(),
|
||||
let mut preferences = Vec::new();
|
||||
|
||||
for agent_chain in agents {
|
||||
preferences.push(ModelUsagePreference {
|
||||
model: agent_chain.id.clone(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: agent.id.clone(),
|
||||
description: agent.description.as_ref().unwrap_or(&String::new()).clone(),
|
||||
name: agent_chain.id.clone(),
|
||||
description: agent_chain.description.clone().unwrap_or_default(),
|
||||
}],
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
}
|
||||
|
||||
preferences
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -183,8 +192,10 @@ mod tests {
|
|||
fn create_test_agent_struct(name: &str) -> Agent {
|
||||
Agent {
|
||||
id: name.to_string(),
|
||||
kind: Some("test".to_string()),
|
||||
agent_type: Some("test".to_string()),
|
||||
url: "http://localhost:8080".to_string(),
|
||||
tool: None,
|
||||
transport: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -240,8 +251,8 @@ mod tests {
|
|||
assert!(agent_map.contains_key("agent2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_agent_description_to_routing_preferences() {
|
||||
#[tokio::test]
|
||||
async fn test_convert_agent_description_to_routing_preferences() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
|
|
@ -250,7 +261,9 @@ mod tests {
|
|||
create_test_agent("agent2", "Second agent description", false),
|
||||
];
|
||||
|
||||
let preferences = selector.convert_agent_description_to_routing_preferences(&agents);
|
||||
let preferences = selector
|
||||
.convert_agent_description_to_routing_preferences(&agents)
|
||||
.await;
|
||||
|
||||
assert_eq!(preferences.len(), 2);
|
||||
assert_eq!(preferences[0].model, "agent1");
|
||||
|
|
|
|||
|
|
@ -42,19 +42,23 @@ mod integration_tests {
|
|||
// Setup services
|
||||
let router_service = create_test_router_service();
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
let pipeline_processor = PipelineProcessor::default();
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
|
||||
// Create test data
|
||||
let agents = vec![
|
||||
Agent {
|
||||
id: "filter-agent".to_string(),
|
||||
kind: Some("filter".to_string()),
|
||||
agent_type: Some("filter".to_string()),
|
||||
url: "http://localhost:8081".to_string(),
|
||||
tool: None,
|
||||
transport: None,
|
||||
},
|
||||
Agent {
|
||||
id: "terminal-agent".to_string(),
|
||||
kind: Some("terminal".to_string()),
|
||||
agent_type: Some("terminal".to_string()),
|
||||
url: "http://localhost:8082".to_string(),
|
||||
tool: None,
|
||||
transport: None,
|
||||
},
|
||||
];
|
||||
|
||||
|
|
@ -107,7 +111,15 @@ mod integration_tests {
|
|||
|
||||
let headers = HeaderMap::new();
|
||||
let result = pipeline_processor
|
||||
.process_filter_chain(&request, &test_pipeline, &agent_map, &headers)
|
||||
.process_filter_chain(
|
||||
&request.messages,
|
||||
&test_pipeline,
|
||||
&agent_map,
|
||||
&headers,
|
||||
None,
|
||||
String::new(),
|
||||
String::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
println!("Pipeline processing result: {:?}", result);
|
||||
|
|
|
|||
49
crates/brightstaff/src/handlers/jsonrpc.rs
Normal file
49
crates/brightstaff/src/handlers/jsonrpc.rs
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub const JSON_RPC_VERSION: &str = "2.0";
|
||||
pub const TOOL_CALL_METHOD : &str = "tools/call";
|
||||
pub const MCP_INITIALIZE: &str = "initialize";
|
||||
pub const MCP_INITIALIZE_NOTIFICATION: &str = "initialize/notification";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum JsonRpcId {
|
||||
String(String),
|
||||
Number(u64),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcRequest {
|
||||
pub jsonrpc: String,
|
||||
pub id: JsonRpcId,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcNotification {
|
||||
pub jsonrpc: String,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcError {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcResponse {
|
||||
pub jsonrpc: String,
|
||||
pub id: JsonRpcId,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<HashMap<String, serde_json::Value>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<JsonRpcError>,
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@ pub mod function_calling;
|
|||
pub mod pipeline_processor;
|
||||
pub mod response_handler;
|
||||
pub mod utils;
|
||||
pub mod jsonrpc;
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests;
|
||||
|
|
|
|||
|
|
@ -1,10 +1,24 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use common::configuration::{Agent, AgentFilterChain};
|
||||
use common::consts::{ARCH_UPSTREAM_HOST_HEADER, ENVOY_RETRY_HEADER};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
|
||||
use common::consts::{
|
||||
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
use common::traces::{SpanBuilder, SpanKind, generate_random_span_id};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use hyper::header::HeaderMap;
|
||||
use tracing::{debug, warn};
|
||||
use std::time::{Instant, SystemTime};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::tracing::operation_component::{self};
|
||||
use crate::tracing::{http, OperationNameBuilder};
|
||||
|
||||
use crate::handlers::jsonrpc::{
|
||||
JsonRpcId, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, JSON_RPC_VERSION,
|
||||
MCP_INITIALIZE, MCP_INITIALIZE_NOTIFICATION, TOOL_CALL_METHOD,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Errors that can occur during pipeline processing
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -19,6 +33,12 @@ pub enum PipelineError {
|
|||
NoChoicesInResponse(String),
|
||||
#[error("No content in response from agent '{0}'")]
|
||||
NoContentInResponse(String),
|
||||
#[error("No result in response from agent '{0}'")]
|
||||
NoResultInResponse(String),
|
||||
#[error("No structured content in response from agent '{0}'")]
|
||||
NoStructuredContentInResponse(String),
|
||||
#[error("No messages in response from agent '{0}'")]
|
||||
NoMessagesInResponse(String),
|
||||
#[error("Client error from agent '{agent}' (HTTP {status}): {body}")]
|
||||
ClientError {
|
||||
agent: String,
|
||||
|
|
@ -37,13 +57,17 @@ pub enum PipelineError {
|
|||
pub struct PipelineProcessor {
|
||||
client: reqwest::Client,
|
||||
url: String,
|
||||
agent_id_session_map: HashMap<String, String>,
|
||||
}
|
||||
|
||||
const ENVOY_API_ROUTER_ADDRESS: &str = "http://localhost:11000";
|
||||
|
||||
impl Default for PipelineProcessor {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
url: "http://localhost:11000/v1/chat/completions".to_string(),
|
||||
url: ENVOY_API_ROUTER_ADDRESS.to_string(),
|
||||
agent_id_session_map: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -53,18 +77,128 @@ impl PipelineProcessor {
|
|||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
url,
|
||||
agent_id_session_map: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a span for filter execution
|
||||
fn record_filter_span(
|
||||
&self,
|
||||
collector: &std::sync::Arc<common::traces::TraceCollector>,
|
||||
agent_name: &str,
|
||||
tool_name: &str,
|
||||
start_time: SystemTime,
|
||||
end_time: SystemTime,
|
||||
elapsed: std::time::Duration,
|
||||
trace_id: String,
|
||||
parent_span_id: String,
|
||||
span_id: String,
|
||||
) -> String {
|
||||
// let (trace_id, parent_span_id) = self.extract_trace_context();
|
||||
|
||||
// Build operation name: POST /agents/* {filter_name}
|
||||
// Using generic path since we don't have access to specific endpoint here
|
||||
let operation_name = OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path("/agents/*")
|
||||
.with_target(agent_name)
|
||||
.build();
|
||||
|
||||
let mut span_builder = SpanBuilder::new(&operation_name)
|
||||
.with_span_id(span_id.clone())
|
||||
.with_kind(SpanKind::Client)
|
||||
.with_start_time(start_time)
|
||||
.with_end_time(end_time)
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, "/agents/*")
|
||||
.with_attribute("filter.name", agent_name.to_string())
|
||||
.with_attribute("filter.tool_name", tool_name.to_string())
|
||||
.with_attribute(
|
||||
"duration_ms",
|
||||
format!("{:.2}", elapsed.as_secs_f64() * 1000.0),
|
||||
);
|
||||
|
||||
if !trace_id.is_empty() {
|
||||
span_builder = span_builder.with_trace_id(trace_id);
|
||||
}
|
||||
if !parent_span_id.is_empty() {
|
||||
span_builder = span_builder.with_parent_span_id(parent_span_id);
|
||||
}
|
||||
|
||||
let span = span_builder.build();
|
||||
// Use plano(filter) as service name for filter execution spans
|
||||
collector.record_span(operation_component::AGENT_FILTER, span);
|
||||
span_id.clone()
|
||||
}
|
||||
|
||||
/// Record a span for MCP protocol interactions
|
||||
fn record_mcp_span(
|
||||
&self,
|
||||
collector: &std::sync::Arc<common::traces::TraceCollector>,
|
||||
operation: &str,
|
||||
agent_id: &str,
|
||||
start_time: SystemTime,
|
||||
end_time: SystemTime,
|
||||
elapsed: std::time::Duration,
|
||||
additional_attrs: Option<HashMap<&str, String>>,
|
||||
trace_id: String,
|
||||
parent_span_id: String,
|
||||
span_id: Option<String>,
|
||||
) {
|
||||
// let (trace_id, parent_span_id) = self.extract_trace_context();
|
||||
|
||||
// Build operation name: POST /mcp {agent_id}
|
||||
let operation_name = OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path("/mcp")
|
||||
.with_operation(operation)
|
||||
.with_target(agent_id)
|
||||
.build();
|
||||
|
||||
let mut span_builder = SpanBuilder::new(&operation_name)
|
||||
.with_span_id(span_id.unwrap_or_else(|| generate_random_span_id()))
|
||||
.with_kind(SpanKind::Client)
|
||||
.with_start_time(start_time)
|
||||
.with_end_time(end_time)
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, &format!("/mcp ({})", operation.to_string()))
|
||||
.with_attribute("mcp.operation", operation.to_string())
|
||||
.with_attribute("mcp.agent_id", agent_id.to_string())
|
||||
.with_attribute(
|
||||
"duration_ms",
|
||||
format!("{:.2}", elapsed.as_secs_f64() * 1000.0),
|
||||
);
|
||||
|
||||
if let Some(attrs) = additional_attrs {
|
||||
for (key, value) in attrs {
|
||||
span_builder = span_builder.with_attribute(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
if !trace_id.is_empty() {
|
||||
span_builder = span_builder.with_trace_id(trace_id);
|
||||
}
|
||||
if !parent_span_id.is_empty() {
|
||||
span_builder = span_builder.with_parent_span_id(parent_span_id);
|
||||
}
|
||||
|
||||
let span = span_builder.build();
|
||||
// MCP spans also use plano(filter) service name as they are part of filter operations
|
||||
collector.record_span(operation_component::AGENT_FILTER, span);
|
||||
}
|
||||
|
||||
/// Process the filter chain of agents (all except the terminal agent)
|
||||
pub async fn process_filter_chain(
|
||||
&self,
|
||||
initial_request: &ChatCompletionsRequest,
|
||||
&mut self,
|
||||
chat_history: &[Message],
|
||||
agent_filter_chain: &AgentFilterChain,
|
||||
agent_map: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
trace_collector: Option<&std::sync::Arc<common::traces::TraceCollector>>,
|
||||
trace_id: String,
|
||||
parent_span_id: String,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
let mut chat_completions_history = initial_request.messages.clone();
|
||||
let mut chat_history_updated = chat_history.to_vec();
|
||||
|
||||
for agent_name in &agent_filter_chain.filter_chain {
|
||||
debug!("Processing filter agent: {}", agent_name);
|
||||
|
|
@ -73,123 +207,490 @@ impl PipelineProcessor {
|
|||
.get(agent_name)
|
||||
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
|
||||
|
||||
debug!("Agent details: {:?}", agent);
|
||||
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
|
||||
|
||||
let response_content = self
|
||||
.send_agent_filter_chain_request(
|
||||
&chat_completions_history,
|
||||
initial_request,
|
||||
info!(
|
||||
"executing filter: {}/{}, url: {}, conversation length: {}",
|
||||
agent_name,
|
||||
tool_name,
|
||||
agent.url,
|
||||
chat_history.len()
|
||||
);
|
||||
|
||||
let start_time = SystemTime::now();
|
||||
let start_instant = Instant::now();
|
||||
|
||||
// Generate filter span ID before execution so MCP spans can use it as parent
|
||||
let filter_span_id = generate_random_span_id();
|
||||
|
||||
chat_history_updated = self
|
||||
.execute_filter(
|
||||
&chat_history_updated,
|
||||
agent,
|
||||
request_headers,
|
||||
trace_collector,
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
debug!("Received response from filter agent {}", agent_name);
|
||||
let end_time = SystemTime::now();
|
||||
let elapsed = start_instant.elapsed();
|
||||
|
||||
// Parse the response content as new message history
|
||||
chat_completions_history =
|
||||
serde_json::from_str(&response_content).inspect_err(|err| {
|
||||
warn!(
|
||||
"Failed to parse response from agent {}, err: {}, response: {}",
|
||||
agent_name, err, response_content
|
||||
)
|
||||
})?;
|
||||
info!(
|
||||
"Filter '{}' completed in {:.2}ms, updated conversation length: {}",
|
||||
agent_name,
|
||||
elapsed.as_secs_f64() * 1000.0,
|
||||
chat_history_updated.len()
|
||||
);
|
||||
|
||||
// Record span for this filter execution
|
||||
if let Some(collector) = trace_collector {
|
||||
self.record_filter_span(
|
||||
collector,
|
||||
agent_name,
|
||||
tool_name,
|
||||
start_time,
|
||||
end_time,
|
||||
elapsed,
|
||||
trace_id.clone(),
|
||||
parent_span_id.clone(),
|
||||
filter_span_id,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(chat_completions_history)
|
||||
Ok(chat_history_updated)
|
||||
}
|
||||
|
||||
/// Send request to a specific agent and return the response content
|
||||
async fn send_agent_filter_chain_request(
|
||||
/// Build common MCP headers for requests
|
||||
fn build_mcp_headers(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
original_request: &ChatCompletionsRequest,
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<String, PipelineError> {
|
||||
let mut request = original_request.clone();
|
||||
request.messages = messages.to_vec();
|
||||
agent_id: &str,
|
||||
session_id: Option<&str>,
|
||||
trace_id: String,
|
||||
parent_span_id: String,
|
||||
) -> Result<HeaderMap, PipelineError> {
|
||||
let trace_parent = format!("00-{}-{}-01", trace_id, parent_span_id);
|
||||
let mut headers = request_headers.clone();
|
||||
headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
let request_body = serde_json::to_string(&request)?;
|
||||
debug!("Sending request to agent {}", agent.id);
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
||||
headers.remove(TRACE_PARENT_HEADER);
|
||||
headers.insert(
|
||||
TRACE_PARENT_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(agent_id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent_id.to_string()))?,
|
||||
);
|
||||
|
||||
headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
|
||||
headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json, text/event-stream"),
|
||||
);
|
||||
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
if let Some(sid) = session_id {
|
||||
headers.insert(
|
||||
"mcp-session-id",
|
||||
hyper::header::HeaderValue::from_str(sid).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
/// Parse SSE formatted response and extract JSON-RPC data
|
||||
fn parse_sse_response(
|
||||
&self,
|
||||
response_bytes: &[u8],
|
||||
agent_id: &str,
|
||||
) -> Result<String, PipelineError> {
|
||||
let response_str = String::from_utf8_lossy(response_bytes);
|
||||
let lines: Vec<&str> = response_str.lines().collect();
|
||||
|
||||
// Validate SSE format: first line should be "event: message"
|
||||
if lines.is_empty() || lines[0] != "event: message" {
|
||||
warn!(
|
||||
"Invalid SSE response format from agent {}: expected 'event: message' as first line, got: {:?}",
|
||||
agent_id,
|
||||
lines.first()
|
||||
);
|
||||
return Err(PipelineError::NoContentInResponse(format!(
|
||||
"Invalid SSE response format from agent {}: expected 'event: message' as first line",
|
||||
agent_id
|
||||
)));
|
||||
}
|
||||
|
||||
// Find the data line
|
||||
let data_lines: Vec<&str> = lines
|
||||
.iter()
|
||||
.filter(|line| line.starts_with("data: "))
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
if data_lines.len() != 1 {
|
||||
warn!(
|
||||
"Expected exactly one 'data:' line from agent {}, found {}",
|
||||
agent_id,
|
||||
data_lines.len()
|
||||
);
|
||||
return Err(PipelineError::NoContentInResponse(format!(
|
||||
"Expected exactly one 'data:' line from agent {}, found {}",
|
||||
agent_id,
|
||||
data_lines.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Skip "data: " prefix
|
||||
Ok(data_lines[0][6..].to_string())
|
||||
}
|
||||
|
||||
/// Send an MCP request and return the response
|
||||
async fn send_mcp_request(
|
||||
&self,
|
||||
json_rpc_request: &JsonRpcRequest,
|
||||
headers: HeaderMap,
|
||||
agent_id: &str,
|
||||
) -> Result<reqwest::Response, PipelineError> {
|
||||
let request_body = serde_json::to_string(json_rpc_request)?;
|
||||
|
||||
debug!(
|
||||
"Sending MCP request to agent {}: {}",
|
||||
agent_id, request_body
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.url)
|
||||
.headers(agent_headers)
|
||||
.post(format!("{}/mcp", self.url))
|
||||
.headers(headers)
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Build a tools/call JSON-RPC request
|
||||
fn build_tool_call_request(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
messages: &[Message],
|
||||
) -> Result<JsonRpcRequest, PipelineError> {
|
||||
let mut arguments = HashMap::new();
|
||||
arguments.insert("messages".to_string(), serde_json::to_value(messages)?);
|
||||
|
||||
let mut params = HashMap::new();
|
||||
params.insert("name".to_string(), serde_json::to_value(tool_name)?);
|
||||
params.insert("arguments".to_string(), serde_json::to_value(arguments)?);
|
||||
|
||||
Ok(JsonRpcRequest {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id: JsonRpcId::String(Uuid::new_v4().to_string()),
|
||||
method: TOOL_CALL_METHOD.to_string(),
|
||||
params: Some(params),
|
||||
})
|
||||
}
|
||||
|
||||
/// Send request to a specific agent and return the response content
|
||||
async fn execute_filter(
|
||||
&mut self,
|
||||
messages: &[Message],
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
trace_collector: Option<&std::sync::Arc<common::traces::TraceCollector>>,
|
||||
trace_id: String,
|
||||
filter_span_id: String,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
// Get or create MCP session
|
||||
let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) {
|
||||
session_id.clone()
|
||||
} else {
|
||||
let session_id = self
|
||||
.get_new_session_id(
|
||||
&agent.id,
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
)
|
||||
.await;
|
||||
self.agent_id_session_map
|
||||
.insert(agent.id.clone(), session_id.clone());
|
||||
session_id
|
||||
};
|
||||
|
||||
info!(
|
||||
"Using MCP session ID {} for agent {}",
|
||||
mcp_session_id, agent.id
|
||||
);
|
||||
|
||||
// Build JSON-RPC request
|
||||
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
|
||||
let json_rpc_request = self.build_tool_call_request(tool_name, messages)?;
|
||||
|
||||
// Generate span ID for this MCP tool call (child of filter span)
|
||||
let mcp_span_id = generate_random_span_id();
|
||||
|
||||
// Build headers
|
||||
let agent_headers =
|
||||
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id), trace_id.clone(), mcp_span_id.clone())?;
|
||||
|
||||
// Send request with tracing
|
||||
let start_time = SystemTime::now();
|
||||
let start_instant = Instant::now();
|
||||
|
||||
let response = self
|
||||
.send_mcp_request(
|
||||
&json_rpc_request,
|
||||
agent_headers,
|
||||
&agent.id,
|
||||
)
|
||||
.await?;
|
||||
let http_status = response.status();
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
||||
// Check for HTTP errors and handle them appropriately
|
||||
if !status.is_success() {
|
||||
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
|
||||
let end_time = SystemTime::now();
|
||||
let elapsed = start_instant.elapsed();
|
||||
|
||||
if status.is_client_error() {
|
||||
// 4xx errors - cascade back to developer
|
||||
return Err(PipelineError::ClientError {
|
||||
agent: agent.id.clone(),
|
||||
status: status.as_u16(),
|
||||
body: error_body,
|
||||
});
|
||||
} else if status.is_server_error() {
|
||||
// 5xx errors - server/agent error
|
||||
return Err(PipelineError::ServerError {
|
||||
agent: agent.id.clone(),
|
||||
status: status.as_u16(),
|
||||
body: error_body,
|
||||
});
|
||||
}
|
||||
// Record MCP tool call span
|
||||
if let Some(collector) = trace_collector {
|
||||
let mut attrs = HashMap::new();
|
||||
attrs.insert("mcp.method", "tools/call".to_string());
|
||||
attrs.insert("mcp.tool_name", tool_name.to_string());
|
||||
attrs.insert("mcp.session_id", mcp_session_id.clone());
|
||||
attrs.insert("http.status_code", http_status.as_u16().to_string());
|
||||
|
||||
self.record_mcp_span(
|
||||
collector,
|
||||
"tool_call",
|
||||
&agent.id,
|
||||
start_time,
|
||||
end_time,
|
||||
elapsed,
|
||||
Some(attrs),
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
Some(mcp_span_id),
|
||||
);
|
||||
}
|
||||
|
||||
// Parse the response as JSON to extract the content
|
||||
let response_json: serde_json::Value = serde_json::from_slice(&response_bytes)?;
|
||||
// Handle HTTP errors
|
||||
if !http_status.is_success() {
|
||||
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
|
||||
return Err(if http_status.is_client_error() {
|
||||
PipelineError::ClientError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
} else {
|
||||
PipelineError::ServerError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let content = response_json
|
||||
.get("choices")
|
||||
.and_then(|choices| choices.as_array())
|
||||
.and_then(|choices| choices.first())
|
||||
.and_then(|choice| choice.get("message"))
|
||||
.and_then(|message| message.get("content"))
|
||||
.and_then(|content| content.as_str())
|
||||
.ok_or_else(|| PipelineError::NoContentInResponse(agent.id.clone()))?
|
||||
info!(
|
||||
"Response from agent {}: {}",
|
||||
agent.id,
|
||||
String::from_utf8_lossy(&response_bytes)
|
||||
);
|
||||
|
||||
// Parse SSE response
|
||||
let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?;
|
||||
let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?;
|
||||
let response_result = response
|
||||
.result
|
||||
.ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?;
|
||||
|
||||
// Check if error field is set in response result
|
||||
if response_result
|
||||
.get("isError")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let error_message = response_result
|
||||
.get("content")
|
||||
.and_then(|v| v.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|v| v.get("text"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown_error")
|
||||
.to_string();
|
||||
|
||||
return Err(PipelineError::ClientError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_message,
|
||||
});
|
||||
}
|
||||
|
||||
// Extract structured content and parse messages
|
||||
let response_json = response_result
|
||||
.get("structuredContent")
|
||||
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
|
||||
|
||||
let messages: Vec<Message> = response_json
|
||||
.get("result")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or_else(|| PipelineError::NoMessagesInResponse(agent.id.clone()))?
|
||||
.iter()
|
||||
.map(|msg_value| serde_json::from_value(msg_value.clone()))
|
||||
.collect::<Result<Vec<Message>, _>>()
|
||||
.map_err(PipelineError::ParseError)?;
|
||||
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
/// Build an initialize JSON-RPC request
|
||||
fn build_initialize_request(&self) -> JsonRpcRequest {
|
||||
JsonRpcRequest {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id: JsonRpcId::String(Uuid::new_v4().to_string()),
|
||||
method: MCP_INITIALIZE.to_string(),
|
||||
params: Some({
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
"protocolVersion".to_string(),
|
||||
serde_json::Value::String("2024-11-05".to_string()),
|
||||
);
|
||||
params.insert("capabilities".to_string(), serde_json::json!({}));
|
||||
params.insert(
|
||||
"clientInfo".to_string(),
|
||||
serde_json::json!({
|
||||
"name": BRIGHT_STAFF_SERVICE_NAME,
|
||||
"version": "1.0.0"
|
||||
}),
|
||||
);
|
||||
params
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send initialized notification after session creation
|
||||
async fn send_initialized_notification(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
session_id: &str,
|
||||
trace_id: String,
|
||||
parent_span_id: String,
|
||||
) -> Result<(), PipelineError> {
|
||||
let initialized_notification = JsonRpcNotification {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
method: MCP_INITIALIZE_NOTIFICATION.to_string(),
|
||||
params: None,
|
||||
};
|
||||
|
||||
let notification_body = serde_json::to_string(&initialized_notification)?;
|
||||
debug!("Sending initialized notification for agent {}", agent_id);
|
||||
|
||||
let headers = self.build_mcp_headers(&HeaderMap::new(), agent_id, Some(session_id), trace_id.clone(), parent_span_id.clone())?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/mcp", self.url))
|
||||
.headers(headers)
|
||||
.body(notification_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
info!(
|
||||
"Initialized notification response status: {}",
|
||||
response.status()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_new_session_id(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
trace_id: String,
|
||||
parent_span_id: String,
|
||||
) -> String {
|
||||
info!("Initializing MCP session for agent {}", agent_id);
|
||||
|
||||
let initialize_request = self.build_initialize_request();
|
||||
let headers = self
|
||||
.build_mcp_headers(&HeaderMap::new(), agent_id, None, trace_id.clone(), parent_span_id.clone())
|
||||
.expect("Failed to build headers for initialization");
|
||||
|
||||
let response = self
|
||||
.send_mcp_request(&initialize_request, headers, agent_id)
|
||||
.await
|
||||
.expect("Failed to initialize MCP session");
|
||||
|
||||
info!("Initialize response status: {}", response.status());
|
||||
|
||||
let session_id = response
|
||||
.headers()
|
||||
.get("mcp-session-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.expect("No mcp-session-id in response")
|
||||
.to_string();
|
||||
|
||||
Ok(content)
|
||||
info!(
|
||||
"Created new MCP session for agent {}: {}",
|
||||
agent_id, session_id
|
||||
);
|
||||
|
||||
// Send initialized notification
|
||||
self.send_initialized_notification(
|
||||
agent_id,
|
||||
&session_id,
|
||||
trace_id.clone(),
|
||||
parent_span_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to send initialized notification");
|
||||
|
||||
session_id
|
||||
}
|
||||
|
||||
/// Send request to terminal agent and return the raw response for streaming
|
||||
pub async fn invoke_upstream_agent(
|
||||
pub async fn invoke_agent(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
original_request: &ChatCompletionsRequest,
|
||||
mut original_request: ProviderRequestType,
|
||||
terminal_agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
trace_id: String,
|
||||
agent_span_id: String,
|
||||
) -> Result<reqwest::Response, PipelineError> {
|
||||
let mut request = original_request.clone();
|
||||
request.messages = messages.to_vec();
|
||||
// let mut request = original_request.clone();
|
||||
original_request.set_messages(messages);
|
||||
|
||||
let request_body = serde_json::to_string(&request)?;
|
||||
let request_body = ProviderRequestType::to_bytes(&original_request).unwrap();
|
||||
// let request_body = serde_json::to_string(&request)?;
|
||||
debug!("Sending request to terminal agent {}", terminal_agent.id);
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
// Set traceparent header to make the egress span a child of the agent span
|
||||
if !trace_id.is_empty() && !agent_span_id.is_empty() {
|
||||
let trace_parent = format!("00-{}-{}-01", trace_id, agent_span_id);
|
||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||
agent_headers.insert(
|
||||
TRACE_PARENT_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&terminal_agent.id)
|
||||
|
|
@ -203,7 +704,7 @@ impl PipelineProcessor {
|
|||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.url)
|
||||
.post(format!("{}/v1/chat/completions", self.url))
|
||||
.headers(agent_headers)
|
||||
.body(request_body)
|
||||
.send()
|
||||
|
|
@ -217,6 +718,7 @@ impl PipelineProcessor {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
||||
use mockito::Server;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn create_test_message(role: Role, content: &str) -> Message {
|
||||
|
|
@ -240,23 +742,149 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_agent_not_found_error() {
|
||||
let processor = PipelineProcessor::default();
|
||||
let mut processor = PipelineProcessor::default();
|
||||
let agent_map = HashMap::new();
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let initial_request = ChatCompletionsRequest {
|
||||
messages: vec![create_test_message(Role::User, "Hello")],
|
||||
model: "test-model".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
let messages = vec![create_test_message(Role::User, "Hello")];
|
||||
|
||||
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
|
||||
|
||||
let result = processor
|
||||
.process_filter_chain(&initial_request, &pipeline, &agent_map, &request_headers)
|
||||
.process_filter_chain(&messages, &pipeline, &agent_map, &request_headers, None, String::new(), String::new())
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
matches!(result.unwrap_err(), PipelineError::AgentNotFound(_));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_filter_http_status_error() {
|
||||
let mut server = Server::new_async().await;
|
||||
let _m = server
|
||||
.mock("POST", "/mcp")
|
||||
.with_status(500)
|
||||
.with_body("boom")
|
||||
.create();
|
||||
|
||||
let server_url = server.url();
|
||||
let mut processor = PipelineProcessor::new(server_url.clone());
|
||||
processor
|
||||
.agent_id_session_map
|
||||
.insert("agent-1".to_string(), "session-1".to_string());
|
||||
|
||||
let agent = Agent {
|
||||
id: "agent-1".to_string(),
|
||||
transport: None,
|
||||
tool: None,
|
||||
url: server_url,
|
||||
agent_type: None,
|
||||
};
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Hello")];
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_filter(&messages, &agent, &request_headers, None, "trace-123".to_string(), "span-123".to_string())
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(PipelineError::ServerError { status, body, .. }) => {
|
||||
assert_eq!(status, 500);
|
||||
assert_eq!(body, "boom");
|
||||
}
|
||||
_ => panic!("Expected server error for 500 status"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_filter_http_client_error() {
|
||||
let mut server = Server::new_async().await;
|
||||
let _m = server
|
||||
.mock("POST", "/mcp")
|
||||
.with_status(400)
|
||||
.with_body("bad request")
|
||||
.create();
|
||||
|
||||
let server_url = server.url();
|
||||
let mut processor = PipelineProcessor::new(server_url.clone());
|
||||
processor
|
||||
.agent_id_session_map
|
||||
.insert("agent-3".to_string(), "session-3".to_string());
|
||||
|
||||
let agent = Agent {
|
||||
id: "agent-3".to_string(),
|
||||
transport: None,
|
||||
tool: None,
|
||||
url: server_url,
|
||||
agent_type: None,
|
||||
};
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Ping")];
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_filter(&messages, &agent, &request_headers, None, "trace-456".to_string(), "span-456".to_string())
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(PipelineError::ClientError { status, body, .. }) => {
|
||||
assert_eq!(status, 400);
|
||||
assert_eq!(body, "bad request");
|
||||
}
|
||||
_ => panic!("Expected client error for 400 status"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_filter_mcp_error_flag() {
|
||||
let rpc_body = serde_json::json!({
|
||||
"jsonrpc": JSON_RPC_VERSION,
|
||||
"id": "1",
|
||||
"result": {
|
||||
"isError": true,
|
||||
"content": [
|
||||
{ "text": "bad tool call" }
|
||||
]
|
||||
}
|
||||
});
|
||||
|
||||
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body.to_string());
|
||||
|
||||
let mut server = Server::new_async().await;
|
||||
let _m = server
|
||||
.mock("POST", "/mcp")
|
||||
.with_status(200)
|
||||
.with_body(sse_body)
|
||||
.create();
|
||||
|
||||
let server_url = server.url();
|
||||
let mut processor = PipelineProcessor::new(server_url.clone());
|
||||
processor
|
||||
.agent_id_session_map
|
||||
.insert("agent-2".to_string(), "session-2".to_string());
|
||||
|
||||
let agent = Agent {
|
||||
id: "agent-2".to_string(),
|
||||
transport: None,
|
||||
tool: None,
|
||||
url: server_url,
|
||||
agent_type: None,
|
||||
};
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Hi")];
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_filter(&messages, &agent, &request_headers, None, "trace-789".to_string(), "span-789".to_string())
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(PipelineError::ClientError { status, body, .. }) => {
|
||||
assert_eq!(status, 200);
|
||||
assert_eq!(body, "bad tool call");
|
||||
}
|
||||
_ => panic!("Expected client error when isError flag is set"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
use brightstaff::handlers::agent_chat_completions::agent_chat;
|
||||
use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
||||
use brightstaff::handlers::llm::llm_chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::handlers::function_calling::{function_calling_chat_handler};
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::Configuration;
|
||||
use common::configuration::{Agent, Configuration};
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
|
||||
use common::traces::TraceCollector;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
|
|
@ -63,8 +63,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let arch_config = Arc::new(config);
|
||||
|
||||
// combine agents and filters into a single list of agents
|
||||
let all_agents: Vec<Agent> = arch_config
|
||||
.agents
|
||||
.as_deref()
|
||||
.unwrap_or_default()
|
||||
.iter()
|
||||
.chain(arch_config.filters.as_deref().unwrap_or_default())
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let llm_providers = Arc::new(RwLock::new(arch_config.model_providers.clone()));
|
||||
let agents_list = Arc::new(RwLock::new(arch_config.agents.clone()));
|
||||
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
||||
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
|
||||
let llm_provider_url =
|
||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
|
@ -98,7 +108,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
info!("Tracing configuration found in arch_config.yaml");
|
||||
Some(true)
|
||||
} else {
|
||||
info!("No tracing configuration in arch_config.yaml, will check OTEL_TRACING_ENABLED env var");
|
||||
info!(
|
||||
"No tracing configuration in arch_config.yaml, will check OTEL_TRACING_ENABLED env var"
|
||||
);
|
||||
None
|
||||
};
|
||||
let trace_collector = Arc::new(TraceCollector::new(tracing_enabled));
|
||||
|
|
@ -142,11 +154,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let io = TokioIo::new(stream);
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||
let model_aliases: Arc<Option<std::collections::HashMap<String, common::configuration::ModelAlias>>> = Arc::clone(&model_aliases);
|
||||
let model_aliases: Arc<
|
||||
Option<std::collections::HashMap<String, common::configuration::ModelAlias>>,
|
||||
> = Arc::clone(&model_aliases);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let agents_list = agents_list.clone();
|
||||
let agents_list = combined_agents_filters_list.clone();
|
||||
let listeners = listeners.clone();
|
||||
let trace_collector = trace_collector.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
|
|
@ -162,28 +176,36 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let state_storage = state_storage.clone();
|
||||
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, req.uri().path());
|
||||
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector, state_storage)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::POST, "/agents/v1/chat/completions") => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, req.uri().path());
|
||||
agent_chat(
|
||||
let path = req.uri().path();
|
||||
// Check if path starts with /agents
|
||||
if path.starts_with("/agents") {
|
||||
// Check if it matches one of the agent API paths
|
||||
let stripped_path = path.strip_prefix("/agents").unwrap();
|
||||
if matches!(
|
||||
stripped_path,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, stripped_path);
|
||||
return agent_chat(
|
||||
req,
|
||||
router_service,
|
||||
fully_qualified_url,
|
||||
agents_list,
|
||||
listeners,
|
||||
trace_collector,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
.await;
|
||||
}
|
||||
}
|
||||
match (req.method(), path) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, path);
|
||||
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector, state_storage)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
|
||||
(&Method::POST, "/function_calling") => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, "/v1/chat/completions");
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ pub mod operation_component {
|
|||
pub const HANDOFF: &str = "plano(handoff)";
|
||||
|
||||
/// Agent filter execution
|
||||
pub const AGENT_FILTER: &str = "plano(agent filter)";
|
||||
pub const AGENT_FILTER: &str = "plano(filter)";
|
||||
|
||||
/// Agent execution
|
||||
pub const AGENT: &str = "plano(agent)";
|
||||
|
|
@ -203,6 +203,7 @@ pub mod operation_component {
|
|||
pub struct OperationNameBuilder {
|
||||
method: Option<String>,
|
||||
path: Option<String>,
|
||||
operation: Option<String>,
|
||||
target: Option<String>,
|
||||
}
|
||||
|
||||
|
|
@ -212,6 +213,7 @@ impl OperationNameBuilder {
|
|||
Self {
|
||||
method: None,
|
||||
path: None,
|
||||
operation: None,
|
||||
target: None,
|
||||
}
|
||||
}
|
||||
|
|
@ -234,6 +236,15 @@ impl OperationNameBuilder {
|
|||
self
|
||||
}
|
||||
|
||||
/// Set the operation type (optional, for MCP operations)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `operation` - Operation type (e.g., "tool_call", "session_init", "notification")
|
||||
pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
|
||||
self.operation = Some(operation.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the target (model name, agent name, or filter name)
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
@ -246,7 +257,8 @@ impl OperationNameBuilder {
|
|||
/// Build the operation name string
|
||||
///
|
||||
/// # Format
|
||||
/// - With all components: `{method} {path} {target}`
|
||||
/// - With all components: `{method} {path} ({operation}) {target}`
|
||||
/// - Without operation: `{method} {path} {target}`
|
||||
/// - Without target: `{method} {path}`
|
||||
/// - Without path: `{method}`
|
||||
/// - Empty: returns empty string
|
||||
|
|
@ -258,7 +270,11 @@ impl OperationNameBuilder {
|
|||
}
|
||||
|
||||
if let Some(path) = self.path {
|
||||
parts.push(path);
|
||||
if let Some(operation) = self.operation {
|
||||
parts.push(format!("{} ({})", path, operation));
|
||||
} else {
|
||||
parts.push(path);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(target) = self.target {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue