mirror of
https://github.com/katanemo/plano.git
synced 2026-05-07 23:02:43 +02:00
plano orchestration using plano orchestration 4b model (#637)
This commit is contained in:
parent
60162e0575
commit
15fbb6c3af
40 changed files with 4054 additions and 449 deletions
|
|
@ -17,7 +17,7 @@ use tracing::{debug, info, warn};
|
|||
use super::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
use super::response_handler::ResponseHandler;
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::router::plano_orchestrator::OrchestratorService;
|
||||
use crate::tracing::{OperationNameBuilder, operation_component, http};
|
||||
|
||||
/// Main errors for agent chat completions
|
||||
|
|
@ -37,7 +37,7 @@ pub enum AgentFilterChainError {
|
|||
|
||||
pub async fn agent_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
_: String,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
|
|
@ -45,7 +45,7 @@ pub async fn agent_chat(
|
|||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
match handle_agent_chat(
|
||||
request,
|
||||
router_service,
|
||||
orchestrator_service,
|
||||
agents_list,
|
||||
listeners,
|
||||
trace_collector,
|
||||
|
|
@ -123,13 +123,13 @@ pub async fn agent_chat(
|
|||
|
||||
async fn handle_agent_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
trace_collector: Arc<common::traces::TraceCollector>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||
// Initialize services
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
let agent_selector = AgentSelector::new(orchestrator_service);
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let response_handler = ResponseHandler::new();
|
||||
|
||||
|
|
@ -186,15 +186,6 @@ async fn handle_agent_chat(
|
|||
|
||||
let message: Vec<OpenAIMessage> = client_request.get_messages();
|
||||
|
||||
// let chat_completions_request: ChatCompletionsRequest =
|
||||
// serde_json::from_slice(&chat_request_bytes).map_err(|err| {
|
||||
// warn!(
|
||||
// "Failed to parse request body as ChatCompletionsRequest: {}",
|
||||
// err
|
||||
// );
|
||||
// AgentFilterChainError::RequestParsing(err)
|
||||
// })?;
|
||||
|
||||
// Extract trace parent for routing
|
||||
let trace_parent = request_headers
|
||||
.iter()
|
||||
|
|
@ -215,94 +206,166 @@ async fn handle_agent_chat(
|
|||
(String::new(), None)
|
||||
};
|
||||
|
||||
// Select appropriate agent using arch router llm model
|
||||
let selected_agent = agent_selector
|
||||
.select_agent(&message, &listener, trace_parent.clone())
|
||||
// Select appropriate agents using arch orchestrator llm model
|
||||
let selection_span_id = generate_random_span_id();
|
||||
let selection_start_time = SystemTime::now();
|
||||
let selection_start_instant = Instant::now();
|
||||
|
||||
let selected_agents = agent_selector
|
||||
.select_agents(&message, &listener, trace_parent.clone())
|
||||
.await?;
|
||||
|
||||
debug!("Processing agent pipeline: {}", selected_agent.id);
|
||||
|
||||
// Record the start time for agent span
|
||||
let agent_start_time = SystemTime::now();
|
||||
let agent_start_instant = Instant::now();
|
||||
// let (span_id, trace_id) = trace_collector.start_span(
|
||||
// trace_parent.clone(),
|
||||
// operation_component::AGENT,
|
||||
// &format!("/agents{}", request_path),
|
||||
// &selected_agent.id,
|
||||
// );
|
||||
|
||||
let span_id = generate_random_span_id();
|
||||
|
||||
// Process the filter chain
|
||||
let chat_history = pipeline_processor
|
||||
.process_filter_chain(
|
||||
&message,
|
||||
&selected_agent,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
Some(&trace_collector),
|
||||
trace_id.clone(),
|
||||
span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Get terminal agent and send final response
|
||||
let terminal_agent_name = selected_agent.id.clone();
|
||||
let terminal_agent = agent_map.get(&terminal_agent_name).unwrap();
|
||||
|
||||
debug!("Processing terminal agent: {}", terminal_agent_name);
|
||||
debug!("Terminal agent details: {:?}", terminal_agent);
|
||||
|
||||
let llm_response = pipeline_processor
|
||||
.invoke_agent(
|
||||
&chat_history,
|
||||
client_request,
|
||||
terminal_agent,
|
||||
&request_headers,
|
||||
trace_id.clone(),
|
||||
span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Record agent span after processing is complete
|
||||
let agent_end_time = SystemTime::now();
|
||||
let agent_elapsed = agent_start_instant.elapsed();
|
||||
|
||||
// Build full path with /agents prefix
|
||||
let full_path = format!("/agents{}", request_path);
|
||||
|
||||
// Build operation name: POST {full_path} {agent_name}
|
||||
let operation_name = OperationNameBuilder::new()
|
||||
// Record agent selection span
|
||||
let selection_end_time = SystemTime::now();
|
||||
let selection_elapsed = selection_start_instant.elapsed();
|
||||
let selection_operation_name = OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path(&full_path)
|
||||
.with_target(&terminal_agent_name)
|
||||
.with_path("/agents/select")
|
||||
.with_target(&listener.name)
|
||||
.build();
|
||||
|
||||
let mut span_builder = SpanBuilder::new(&operation_name)
|
||||
.with_span_id(span_id)
|
||||
let mut selection_span_builder = SpanBuilder::new(&selection_operation_name)
|
||||
.with_span_id(selection_span_id)
|
||||
.with_kind(SpanKind::Internal)
|
||||
.with_start_time(agent_start_time)
|
||||
.with_end_time(agent_end_time)
|
||||
.with_start_time(selection_start_time)
|
||||
.with_end_time(selection_end_time)
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, full_path)
|
||||
.with_attribute("agent.name", terminal_agent_name.clone())
|
||||
.with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0));
|
||||
.with_attribute(http::TARGET, "/agents/select")
|
||||
.with_attribute("selection.listener", listener.name.clone())
|
||||
.with_attribute("selection.agent_count", selected_agents.len().to_string())
|
||||
.with_attribute("selection.agents", selected_agents.iter().map(|a| a.id.as_str()).collect::<Vec<_>>().join(","))
|
||||
.with_attribute("duration_ms", format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0));
|
||||
|
||||
if !trace_id.is_empty() {
|
||||
span_builder = span_builder.with_trace_id(trace_id);
|
||||
selection_span_builder = selection_span_builder.with_trace_id(trace_id.clone());
|
||||
}
|
||||
if let Some(parent_id) = parent_span_id {
|
||||
span_builder = span_builder.with_parent_span_id(parent_id);
|
||||
if let Some(parent_id) = parent_span_id.clone() {
|
||||
selection_span_builder = selection_span_builder.with_parent_span_id(parent_id);
|
||||
}
|
||||
|
||||
let span = span_builder.build();
|
||||
// Use plano(agent) as service name for the agent processing span
|
||||
trace_collector.record_span(operation_component::AGENT, span);
|
||||
let selection_span = selection_span_builder.build();
|
||||
trace_collector.record_span(operation_component::ORCHESTRATOR, selection_span);
|
||||
|
||||
// Create streaming response
|
||||
response_handler
|
||||
.create_streaming_response(llm_response)
|
||||
.await
|
||||
.map_err(AgentFilterChainError::from)
|
||||
info!("Selected {} agent(s) for execution", selected_agents.len());
|
||||
|
||||
// Execute agents sequentially, passing output from one to the next
|
||||
let mut current_messages = message.clone();
|
||||
let agent_count = selected_agents.len();
|
||||
|
||||
for (agent_index, selected_agent) in selected_agents.iter().enumerate() {
|
||||
let is_last_agent = agent_index == agent_count - 1;
|
||||
|
||||
debug!(
|
||||
"Processing agent {}/{}: {}",
|
||||
agent_index + 1,
|
||||
agent_count,
|
||||
selected_agent.id
|
||||
);
|
||||
|
||||
// Record the start time for agent span
|
||||
let agent_start_time = SystemTime::now();
|
||||
let agent_start_instant = Instant::now();
|
||||
let span_id = generate_random_span_id();
|
||||
|
||||
// Get agent name
|
||||
let agent_name = selected_agent.id.clone();
|
||||
|
||||
// Process the filter chain
|
||||
let chat_history = pipeline_processor
|
||||
.process_filter_chain(
|
||||
¤t_messages,
|
||||
selected_agent,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
Some(&trace_collector),
|
||||
trace_id.clone(),
|
||||
span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Get agent details and invoke
|
||||
let agent = agent_map.get(&agent_name).unwrap();
|
||||
|
||||
debug!("Invoking agent: {}", agent_name);
|
||||
|
||||
let llm_response = pipeline_processor
|
||||
.invoke_agent(
|
||||
&chat_history,
|
||||
client_request.clone(),
|
||||
agent,
|
||||
&request_headers,
|
||||
trace_id.clone(),
|
||||
span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Record agent span
|
||||
let agent_end_time = SystemTime::now();
|
||||
let agent_elapsed = agent_start_instant.elapsed();
|
||||
let full_path = format!("/agents{}", request_path);
|
||||
let operation_name = OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path(&full_path)
|
||||
.with_target(&agent_name)
|
||||
.build();
|
||||
|
||||
let mut span_builder = SpanBuilder::new(&operation_name)
|
||||
.with_span_id(span_id)
|
||||
.with_kind(SpanKind::Internal)
|
||||
.with_start_time(agent_start_time)
|
||||
.with_end_time(agent_end_time)
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, full_path)
|
||||
.with_attribute("agent.name", agent_name.clone())
|
||||
.with_attribute("agent.sequence", format!("{}/{}", agent_index + 1, agent_count))
|
||||
.with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0));
|
||||
|
||||
if !trace_id.is_empty() {
|
||||
span_builder = span_builder.with_trace_id(trace_id.clone());
|
||||
}
|
||||
if let Some(parent_id) = parent_span_id.clone() {
|
||||
span_builder = span_builder.with_parent_span_id(parent_id);
|
||||
}
|
||||
|
||||
let span = span_builder.build();
|
||||
trace_collector.record_span(operation_component::AGENT, span);
|
||||
|
||||
// If this is the last agent, return the streaming response
|
||||
if is_last_agent {
|
||||
info!("Completed agent chain, returning response from last agent: {}", agent_name);
|
||||
return response_handler
|
||||
.create_streaming_response(llm_response)
|
||||
.await
|
||||
.map_err(AgentFilterChainError::from);
|
||||
}
|
||||
|
||||
// For intermediate agents, collect the full response and pass to next agent
|
||||
debug!("Collecting response from intermediate agent: {}", agent_name);
|
||||
let response_text = response_handler.collect_full_response(llm_response).await?;
|
||||
|
||||
info!(
|
||||
"Agent {} completed, passing {} character response to next agent",
|
||||
agent_name,
|
||||
response_text.len()
|
||||
);
|
||||
|
||||
// remove last message and add new one at the end
|
||||
let last_message = current_messages.pop().unwrap();
|
||||
|
||||
// Create a new message with the agent's response as assistant message
|
||||
// and add it to the conversation history
|
||||
current_messages.push(OpenAIMessage {
|
||||
role: hermesllm::apis::openai::Role::Assistant,
|
||||
content: hermesllm::apis::openai::MessageContent::Text(response_text),
|
||||
name: Some(agent_name.clone()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
|
||||
current_messages.push(last_message);
|
||||
|
||||
}
|
||||
|
||||
// This should never be reached since we return in the last agent iteration
|
||||
unreachable!("Agent execution loop should have returned a response")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,12 +2,12 @@ use std::collections::HashMap;
|
|||
use std::sync::Arc;
|
||||
|
||||
use common::configuration::{
|
||||
Agent, AgentFilterChain, Listener, ModelUsagePreference, RoutingPreference,
|
||||
Agent, AgentFilterChain, Listener, AgentUsagePreference, OrchestrationPreference,
|
||||
};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::router::plano_orchestrator::OrchestratorService;
|
||||
|
||||
/// Errors that can occur during agent selection
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -16,23 +16,23 @@ pub enum AgentSelectionError {
|
|||
ListenerNotFound(String),
|
||||
#[error("No agents configured for listener: {0}")]
|
||||
NoAgentsConfigured(String),
|
||||
#[error("Routing service error: {0}")]
|
||||
RoutingError(String),
|
||||
#[error("Default agent not found for listener: {0}")]
|
||||
DefaultAgentNotFound(String),
|
||||
#[error("MCP client error: {0}")]
|
||||
McpError(String),
|
||||
#[error("Orchestration service error: {0}")]
|
||||
OrchestrationError(String),
|
||||
}
|
||||
|
||||
/// Service for selecting agents based on routing preferences and listener configuration
|
||||
/// Service for selecting agents based on orchestration preferences and listener configuration
|
||||
pub struct AgentSelector {
|
||||
router_service: Arc<RouterService>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
}
|
||||
|
||||
impl AgentSelector {
|
||||
pub fn new(router_service: Arc<RouterService>) -> Self {
|
||||
pub fn new(orchestrator_service: Arc<OrchestratorService>) -> Self {
|
||||
Self {
|
||||
router_service,
|
||||
orchestrator_service,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -63,59 +63,6 @@ impl AgentSelector {
|
|||
.collect()
|
||||
}
|
||||
|
||||
/// Select appropriate agent based on routing preferences
|
||||
pub async fn select_agent(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
listener: &Listener,
|
||||
trace_parent: Option<String>,
|
||||
) -> Result<AgentFilterChain, AgentSelectionError> {
|
||||
let agents = listener
|
||||
.agents
|
||||
.as_ref()
|
||||
.ok_or_else(|| AgentSelectionError::NoAgentsConfigured(listener.name.clone()))?;
|
||||
|
||||
// If only one agent, skip routing
|
||||
if agents.len() == 1 {
|
||||
debug!("Only one agent available, skipping routing");
|
||||
return Ok(agents[0].clone());
|
||||
}
|
||||
|
||||
let usage_preferences = self
|
||||
.convert_agent_description_to_routing_preferences(agents)
|
||||
.await;
|
||||
debug!(
|
||||
"Agents usage preferences for agent routing str: {}",
|
||||
serde_json::to_string(&usage_preferences).unwrap_or_default()
|
||||
);
|
||||
|
||||
match self
|
||||
.router_service
|
||||
.determine_route(messages, trace_parent, Some(usage_preferences))
|
||||
.await
|
||||
{
|
||||
Ok(Some((_, agent_name))) => {
|
||||
debug!("Determined agent: {}", agent_name);
|
||||
let selected_agent = agents
|
||||
.iter()
|
||||
.find(|a| a.id == agent_name)
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
AgentSelectionError::RoutingError(format!(
|
||||
"Selected agent '{}' not found in listener agents",
|
||||
agent_name
|
||||
))
|
||||
})?;
|
||||
Ok(selected_agent)
|
||||
}
|
||||
Ok(None) => {
|
||||
debug!("No agent determined using routing preferences, using default agent");
|
||||
self.get_default_agent(agents, &listener.name)
|
||||
}
|
||||
Err(err) => Err(AgentSelectionError::RoutingError(err.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the default agent or the first agent if no default is specified
|
||||
fn get_default_agent(
|
||||
&self,
|
||||
|
|
@ -136,17 +83,17 @@ impl AgentSelector {
|
|||
.ok_or_else(|| AgentSelectionError::DefaultAgentNotFound(listener_name.to_string()))
|
||||
}
|
||||
|
||||
/// Convert agent descriptions to routing preferences
|
||||
async fn convert_agent_description_to_routing_preferences(
|
||||
/// Convert agent descriptions to orchestration preferences
|
||||
async fn convert_agent_description_to_orchestration_preferences(
|
||||
&self,
|
||||
agents: &[AgentFilterChain],
|
||||
) -> Vec<ModelUsagePreference> {
|
||||
) -> Vec<AgentUsagePreference> {
|
||||
let mut preferences = Vec::new();
|
||||
|
||||
for agent_chain in agents {
|
||||
preferences.push(ModelUsagePreference {
|
||||
preferences.push(AgentUsagePreference {
|
||||
model: agent_chain.id.clone(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
orchestration_preferences: vec![OrchestrationPreference {
|
||||
name: agent_chain.id.clone(),
|
||||
description: agent_chain.description.clone().unwrap_or_default(),
|
||||
}],
|
||||
|
|
@ -155,6 +102,71 @@ impl AgentSelector {
|
|||
|
||||
preferences
|
||||
}
|
||||
|
||||
/// Select multiple agents using orchestration
|
||||
pub async fn select_agents(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
listener: &Listener,
|
||||
trace_parent: Option<String>,
|
||||
) -> Result<Vec<AgentFilterChain>, AgentSelectionError> {
|
||||
let agents = listener
|
||||
.agents
|
||||
.as_ref()
|
||||
.ok_or_else(|| AgentSelectionError::NoAgentsConfigured(listener.name.clone()))?;
|
||||
|
||||
// If only one agent, skip orchestration
|
||||
if agents.len() == 1 {
|
||||
debug!("Only one agent available, skipping orchestration");
|
||||
return Ok(vec![agents[0].clone()]);
|
||||
}
|
||||
|
||||
let usage_preferences = self
|
||||
.convert_agent_description_to_orchestration_preferences(agents)
|
||||
.await;
|
||||
debug!(
|
||||
"Agents usage preferences for orchestration: {}",
|
||||
serde_json::to_string(&usage_preferences).unwrap_or_default()
|
||||
);
|
||||
|
||||
match self
|
||||
.orchestrator_service
|
||||
.determine_orchestration(messages, trace_parent, Some(usage_preferences))
|
||||
.await
|
||||
{
|
||||
Ok(Some(routes)) => {
|
||||
debug!("Determined {} agent(s) via orchestration", routes.len());
|
||||
let mut selected_agents = Vec::new();
|
||||
|
||||
for (route_name, agent_name) in routes {
|
||||
debug!("Processing route: {}, agent: {}", route_name, agent_name);
|
||||
let selected_agent = agents
|
||||
.iter()
|
||||
.find(|a| a.id == agent_name)
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
AgentSelectionError::OrchestrationError(format!(
|
||||
"Selected agent '{}' not found in listener agents",
|
||||
agent_name
|
||||
))
|
||||
})?;
|
||||
selected_agents.push(selected_agent);
|
||||
}
|
||||
|
||||
if selected_agents.is_empty() {
|
||||
debug!("No agents determined using orchestration, using default agent");
|
||||
Ok(vec![self.get_default_agent(agents, &listener.name)?])
|
||||
} else {
|
||||
Ok(selected_agents)
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
debug!("No agents determined using orchestration, using default agent");
|
||||
Ok(vec![self.get_default_agent(agents, &listener.name)?])
|
||||
}
|
||||
Err(err) => Err(AgentSelectionError::OrchestrationError(err.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -162,12 +174,10 @@ mod tests {
|
|||
use super::*;
|
||||
use common::configuration::{AgentFilterChain, Listener};
|
||||
|
||||
fn create_test_router_service() -> Arc<RouterService> {
|
||||
Arc::new(RouterService::new(
|
||||
vec![], // empty providers for testing
|
||||
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
|
||||
Arc::new(OrchestratorService::new(
|
||||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"test-provider".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
|
|
@ -176,7 +186,7 @@ mod tests {
|
|||
id: name.to_string(),
|
||||
description: Some(description.to_string()),
|
||||
default: Some(is_default),
|
||||
filter_chain: vec![name.to_string()],
|
||||
filter_chain: Some(vec![name.to_string()]),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -201,8 +211,8 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_find_listener_success() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
let orchestrator_service = create_test_orchestrator_service();
|
||||
let selector = AgentSelector::new(orchestrator_service);
|
||||
|
||||
let listener1 = create_test_listener("test-listener", vec![]);
|
||||
let listener2 = create_test_listener("other-listener", vec![]);
|
||||
|
|
@ -218,8 +228,8 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_find_listener_not_found() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
let orchestrator_service = create_test_orchestrator_service();
|
||||
let selector = AgentSelector::new(orchestrator_service);
|
||||
|
||||
let listeners = vec![create_test_listener("other-listener", vec![])];
|
||||
|
||||
|
|
@ -236,8 +246,8 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_create_agent_map() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
let orchestrator_service = create_test_orchestrator_service();
|
||||
let selector = AgentSelector::new(orchestrator_service);
|
||||
|
||||
let agents = vec![
|
||||
create_test_agent_struct("agent1"),
|
||||
|
|
@ -251,33 +261,10 @@ mod tests {
|
|||
assert!(agent_map.contains_key("agent2"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_convert_agent_description_to_routing_preferences() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
let agents = vec![
|
||||
create_test_agent("agent1", "First agent description", true),
|
||||
create_test_agent("agent2", "Second agent description", false),
|
||||
];
|
||||
|
||||
let preferences = selector
|
||||
.convert_agent_description_to_routing_preferences(&agents)
|
||||
.await;
|
||||
|
||||
assert_eq!(preferences.len(), 2);
|
||||
assert_eq!(preferences[0].model, "agent1");
|
||||
assert_eq!(preferences[0].routing_preferences[0].name, "agent1");
|
||||
assert_eq!(
|
||||
preferences[0].routing_preferences[0].description,
|
||||
"First agent description"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_default_agent() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
let orchestrator_service = create_test_orchestrator_service();
|
||||
let selector = AgentSelector::new(orchestrator_service);
|
||||
|
||||
let agents = vec![
|
||||
create_test_agent("agent1", "First agent", false),
|
||||
|
|
@ -293,8 +280,8 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_get_default_agent_fallback_to_first() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
let orchestrator_service = create_test_orchestrator_service();
|
||||
let selector = AgentSelector::new(orchestrator_service);
|
||||
|
||||
let agents = vec![
|
||||
create_test_agent("agent1", "First agent", false),
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@ use hyper::header::HeaderMap;
|
|||
use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
use crate::handlers::pipeline_processor::PipelineProcessor;
|
||||
use crate::handlers::response_handler::ResponseHandler;
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::router::plano_orchestrator::OrchestratorService;
|
||||
|
||||
/// Integration test that demonstrates the modular agent chat flow
|
||||
/// This test shows how the three main components work together:
|
||||
/// 1. AgentSelector - selects the appropriate agent based on routing
|
||||
/// 1. AgentSelector - selects the appropriate agents based on orchestration
|
||||
/// 2. PipelineProcessor - executes the agent pipeline
|
||||
/// 3. ResponseHandler - handles response streaming
|
||||
#[cfg(test)]
|
||||
|
|
@ -18,12 +18,10 @@ mod integration_tests {
|
|||
use super::*;
|
||||
use common::configuration::{Agent, AgentFilterChain, Listener};
|
||||
|
||||
fn create_test_router_service() -> Arc<RouterService> {
|
||||
Arc::new(RouterService::new(
|
||||
vec![], // empty providers for testing
|
||||
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
|
||||
Arc::new(OrchestratorService::new(
|
||||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"test-provider".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
|
|
@ -40,8 +38,8 @@ mod integration_tests {
|
|||
#[tokio::test]
|
||||
async fn test_modular_agent_chat_flow() {
|
||||
// Setup services
|
||||
let router_service = create_test_router_service();
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
let orchestrator_service = create_test_orchestrator_service();
|
||||
let agent_selector = AgentSelector::new(orchestrator_service);
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
|
||||
// Create test data
|
||||
|
|
@ -64,7 +62,7 @@ mod integration_tests {
|
|||
|
||||
let agent_pipeline = AgentFilterChain {
|
||||
id: "terminal-agent".to_string(),
|
||||
filter_chain: vec!["filter-agent".to_string(), "terminal-agent".to_string()],
|
||||
filter_chain: Some(vec!["filter-agent".to_string(), "terminal-agent".to_string()]),
|
||||
description: Some("Test pipeline".to_string()),
|
||||
default: Some(true),
|
||||
};
|
||||
|
|
@ -104,7 +102,7 @@ mod integration_tests {
|
|||
// Create a pipeline with empty filter chain to avoid network calls
|
||||
let test_pipeline = AgentFilterChain {
|
||||
id: "terminal-agent".to_string(),
|
||||
filter_chain: vec![], // Empty filter chain - no network calls needed
|
||||
filter_chain: Some(vec![]), // Empty filter chain - no network calls needed
|
||||
description: None,
|
||||
default: None,
|
||||
};
|
||||
|
|
@ -143,7 +141,7 @@ mod integration_tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_error_handling_flow() {
|
||||
let router_service = create_test_router_service();
|
||||
let router_service = create_test_orchestrator_service();
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
|
||||
// Test listener not found
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use common::configuration::{Agent, AgentFilterChain};
|
|||
use common::consts::{
|
||||
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
use common::traces::{SpanBuilder, SpanKind, generate_random_span_id};
|
||||
use common::traces::{generate_random_span_id, SpanBuilder, SpanKind};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use hyper::header::HeaderMap;
|
||||
|
|
@ -200,7 +200,13 @@ impl PipelineProcessor {
|
|||
) -> Result<Vec<Message>, PipelineError> {
|
||||
let mut chat_history_updated = chat_history.to_vec();
|
||||
|
||||
for agent_name in &agent_filter_chain.filter_chain {
|
||||
// If filter_chain is None or empty, proceed without filtering
|
||||
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
|
||||
Some(fc) if !fc.is_empty() => fc,
|
||||
_ => return Ok(chat_history_updated),
|
||||
};
|
||||
|
||||
for agent_name in filter_chain {
|
||||
debug!("Processing filter agent: {}", agent_name);
|
||||
|
||||
let agent = agent_map
|
||||
|
|
@ -210,10 +216,11 @@ impl PipelineProcessor {
|
|||
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
|
||||
|
||||
info!(
|
||||
"executing filter: {}/{}, url: {}, conversation length: {}",
|
||||
"executing filter: {}/{}, url: {}, type: {}, conversation length: {}",
|
||||
agent_name,
|
||||
tool_name,
|
||||
agent.url,
|
||||
agent.agent_type.as_deref().unwrap_or("mcp"),
|
||||
chat_history.len()
|
||||
);
|
||||
|
||||
|
|
@ -223,16 +230,29 @@ impl PipelineProcessor {
|
|||
// Generate filter span ID before execution so MCP spans can use it as parent
|
||||
let filter_span_id = generate_random_span_id();
|
||||
|
||||
chat_history_updated = self
|
||||
.execute_filter(
|
||||
&chat_history_updated,
|
||||
agent,
|
||||
request_headers,
|
||||
trace_collector,
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
if agent.agent_type.as_deref().unwrap_or("mcp") == "mcp" {
|
||||
chat_history_updated = self
|
||||
.execute_mcp_filter(
|
||||
&chat_history_updated,
|
||||
agent,
|
||||
request_headers,
|
||||
trace_collector,
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
chat_history_updated = self
|
||||
.execute_rest_filter(
|
||||
&chat_history_updated,
|
||||
agent,
|
||||
request_headers,
|
||||
trace_collector,
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let end_time = SystemTime::now();
|
||||
let elapsed = start_instant.elapsed();
|
||||
|
|
@ -406,7 +426,7 @@ impl PipelineProcessor {
|
|||
}
|
||||
|
||||
/// Send request to a specific agent and return the response content
|
||||
async fn execute_filter(
|
||||
async fn execute_mcp_filter(
|
||||
&mut self,
|
||||
messages: &[Message],
|
||||
agent: &Agent,
|
||||
|
|
@ -420,11 +440,7 @@ impl PipelineProcessor {
|
|||
session_id.clone()
|
||||
} else {
|
||||
let session_id = self
|
||||
.get_new_session_id(
|
||||
&agent.id,
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
)
|
||||
.get_new_session_id(&agent.id, trace_id.clone(), filter_span_id.clone())
|
||||
.await;
|
||||
self.agent_id_session_map
|
||||
.insert(agent.id.clone(), session_id.clone());
|
||||
|
|
@ -444,19 +460,20 @@ impl PipelineProcessor {
|
|||
let mcp_span_id = generate_random_span_id();
|
||||
|
||||
// Build headers
|
||||
let agent_headers =
|
||||
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id), trace_id.clone(), mcp_span_id.clone())?;
|
||||
let agent_headers = self.build_mcp_headers(
|
||||
request_headers,
|
||||
&agent.id,
|
||||
Some(&mcp_session_id),
|
||||
trace_id.clone(),
|
||||
mcp_span_id.clone(),
|
||||
)?;
|
||||
|
||||
// Send request with tracing
|
||||
let start_time = SystemTime::now();
|
||||
let start_instant = Instant::now();
|
||||
|
||||
let response = self
|
||||
.send_mcp_request(
|
||||
&json_rpc_request,
|
||||
agent_headers,
|
||||
&agent.id,
|
||||
)
|
||||
.send_mcp_request(&json_rpc_request, agent_headers, &agent.id)
|
||||
.await?;
|
||||
let http_status = response.status();
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
|
@ -598,7 +615,13 @@ impl PipelineProcessor {
|
|||
let notification_body = serde_json::to_string(&initialized_notification)?;
|
||||
debug!("Sending initialized notification for agent {}", agent_id);
|
||||
|
||||
let headers = self.build_mcp_headers(&HeaderMap::new(), agent_id, Some(session_id), trace_id.clone(), parent_span_id.clone())?;
|
||||
let headers = self.build_mcp_headers(
|
||||
&HeaderMap::new(),
|
||||
agent_id,
|
||||
Some(session_id),
|
||||
trace_id.clone(),
|
||||
parent_span_id.clone(),
|
||||
)?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
|
|
@ -626,7 +649,13 @@ impl PipelineProcessor {
|
|||
|
||||
let initialize_request = self.build_initialize_request();
|
||||
let headers = self
|
||||
.build_mcp_headers(&HeaderMap::new(), agent_id, None, trace_id.clone(), parent_span_id.clone())
|
||||
.build_mcp_headers(
|
||||
&HeaderMap::new(),
|
||||
agent_id,
|
||||
None,
|
||||
trace_id.clone(),
|
||||
parent_span_id.clone(),
|
||||
)
|
||||
.expect("Failed to build headers for initialization");
|
||||
|
||||
let response = self
|
||||
|
|
@ -661,6 +690,129 @@ impl PipelineProcessor {
|
|||
session_id
|
||||
}
|
||||
|
||||
/// Execute a REST-based filter agent
|
||||
async fn execute_rest_filter(
|
||||
&mut self,
|
||||
messages: &[Message],
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
trace_collector: Option<&std::sync::Arc<common::traces::TraceCollector>>,
|
||||
trace_id: String,
|
||||
filter_span_id: String,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
|
||||
|
||||
// Generate span ID for this REST call (child of filter span)
|
||||
let rest_span_id = generate_random_span_id();
|
||||
|
||||
// Build headers
|
||||
let trace_parent = format!("00-{}-{}-01", trace_id, rest_span_id);
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||
agent_headers.insert(
|
||||
TRACE_PARENT_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
"Content-Type",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
// Send request with tracing
|
||||
let start_time = SystemTime::now();
|
||||
let start_instant = Instant::now();
|
||||
|
||||
debug!(
|
||||
"Sending REST request to agent {} at URL: {}",
|
||||
agent.id, agent.url
|
||||
);
|
||||
|
||||
// Send messages array directly as request body
|
||||
let response = self
|
||||
.client
|
||||
.post(&agent.url)
|
||||
.headers(agent_headers)
|
||||
.json(&messages)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let http_status = response.status();
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
||||
let end_time = SystemTime::now();
|
||||
let elapsed = start_instant.elapsed();
|
||||
|
||||
// Record REST call span
|
||||
if let Some(collector) = trace_collector {
|
||||
let mut attrs = HashMap::new();
|
||||
attrs.insert("rest.tool_name", tool_name.to_string());
|
||||
attrs.insert("rest.url", agent.url.clone());
|
||||
attrs.insert("http.status_code", http_status.as_u16().to_string());
|
||||
|
||||
self.record_mcp_span(
|
||||
collector,
|
||||
"rest_call",
|
||||
&agent.id,
|
||||
start_time,
|
||||
end_time,
|
||||
elapsed,
|
||||
Some(attrs),
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
Some(rest_span_id),
|
||||
);
|
||||
}
|
||||
|
||||
// Handle HTTP errors
|
||||
if !http_status.is_success() {
|
||||
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
|
||||
return Err(if http_status.is_client_error() {
|
||||
PipelineError::ClientError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
} else {
|
||||
PipelineError::ServerError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
info!(
|
||||
"Response from REST agent {}: {}",
|
||||
agent.id,
|
||||
String::from_utf8_lossy(&response_bytes)
|
||||
);
|
||||
|
||||
// Parse response - expecting array of messages directly
|
||||
let messages: Vec<Message> =
|
||||
serde_json::from_slice(&response_bytes).map_err(PipelineError::ParseError)?;
|
||||
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
/// Send request to terminal agent and return the raw response for streaming
|
||||
pub async fn invoke_agent(
|
||||
&self,
|
||||
|
|
@ -734,7 +886,7 @@ mod tests {
|
|||
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
|
||||
AgentFilterChain {
|
||||
id: "test-agent".to_string(),
|
||||
filter_chain: agents.iter().map(|s| s.to_string()).collect(),
|
||||
filter_chain: Some(agents.iter().map(|s| s.to_string()).collect()),
|
||||
description: None,
|
||||
default: None,
|
||||
}
|
||||
|
|
@ -751,7 +903,15 @@ mod tests {
|
|||
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
|
||||
|
||||
let result = processor
|
||||
.process_filter_chain(&messages, &pipeline, &agent_map, &request_headers, None, String::new(), String::new())
|
||||
.process_filter_chain(
|
||||
&messages,
|
||||
&pipeline,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
None,
|
||||
String::new(),
|
||||
String::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
|
|
@ -785,7 +945,14 @@ mod tests {
|
|||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_filter(&messages, &agent, &request_headers, None, "trace-123".to_string(), "span-123".to_string())
|
||||
.execute_mcp_filter(
|
||||
&messages,
|
||||
&agent,
|
||||
&request_headers,
|
||||
None,
|
||||
"trace-123".to_string(),
|
||||
"span-123".to_string(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
|
@ -824,7 +991,14 @@ mod tests {
|
|||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_filter(&messages, &agent, &request_headers, None, "trace-456".to_string(), "span-456".to_string())
|
||||
.execute_mcp_filter(
|
||||
&messages,
|
||||
&agent,
|
||||
&request_headers,
|
||||
None,
|
||||
"trace-456".to_string(),
|
||||
"span-456".to_string(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
|
@ -876,7 +1050,14 @@ mod tests {
|
|||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_filter(&messages, &agent, &request_headers, None, "trace-789".to_string(), "span-789".to_string())
|
||||
.execute_mcp_filter(
|
||||
&messages,
|
||||
&agent,
|
||||
&request_headers,
|
||||
None,
|
||||
"trace-789".to_string(),
|
||||
"span-789".to_string(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
use bytes::Bytes;
|
||||
use hermesllm::apis::OpenAIApi;
|
||||
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use hermesllm::SseEvent;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
|
|
@ -6,7 +9,7 @@ use hyper::{Response, StatusCode};
|
|||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::warn;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Errors that can occur during response handling
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -113,6 +116,74 @@ impl ResponseHandler {
|
|||
.body(stream_body)
|
||||
.map_err(ResponseError::from)
|
||||
}
|
||||
|
||||
/// Collect the full response body as a string
|
||||
/// This is used for intermediate agents where we need to capture the full response
|
||||
/// before passing it to the next agent.
|
||||
///
|
||||
/// This method handles both streaming and non-streaming responses:
|
||||
/// - For streaming SSE responses: parses chunks and extracts text deltas
|
||||
/// - For non-streaming responses: returns the full text
|
||||
pub async fn collect_full_response(
|
||||
&self,
|
||||
llm_response: reqwest::Response,
|
||||
) -> Result<String, ResponseError> {
|
||||
use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
|
||||
|
||||
let response_headers = llm_response.headers();
|
||||
let is_sse_streaming = response_headers
|
||||
.get(hyper::header::CONTENT_TYPE)
|
||||
.map_or(false, |v| {
|
||||
v.to_str().unwrap_or("").contains("text/event-stream")
|
||||
});
|
||||
|
||||
let response_bytes = llm_response
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|e| ResponseError::StreamError(format!("Failed to read response: {}", e)))?;
|
||||
|
||||
if is_sse_streaming {
|
||||
let client_api =
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let upstream_api =
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).unwrap();
|
||||
let mut accumulated_text = String::new();
|
||||
|
||||
for sse_event in sse_iter {
|
||||
// Skip [DONE] markers and event-only lines
|
||||
if sse_event.is_done() || sse_event.is_event_only() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let transformed_event =
|
||||
SseEvent::try_from((sse_event, &client_api, &upstream_api)).unwrap();
|
||||
|
||||
// Try to get provider response and extract content delta
|
||||
match transformed_event.provider_response() {
|
||||
Ok(provider_response) => {
|
||||
if let Some(content) = provider_response.content_delta() {
|
||||
accumulated_text.push_str(&content);
|
||||
} else {
|
||||
info!("No content delta in provider response");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse provider response: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ok(accumulated_text);
|
||||
} else {
|
||||
// If not SSE, treat as regular text response
|
||||
let response_text = String::from_utf8(response_bytes.to_vec()).map_err(|e| {
|
||||
ResponseError::StreamError(format!("Failed to decode response: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(response_text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ResponseHandler {
|
||||
|
|
|
|||
|
|
@ -3,13 +3,14 @@ use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
|||
use brightstaff::handlers::llm::llm_chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
use brightstaff::router::plano_orchestrator::OrchestratorService;
|
||||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, Configuration};
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME};
|
||||
use common::traces::TraceCollector;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use hyper::body::Incoming;
|
||||
|
|
@ -95,10 +96,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||
arch_config.model_providers.clone(),
|
||||
llm_provider_url.clone() + CHAT_COMPLETIONS_PATH,
|
||||
routing_model_name,
|
||||
routing_llm_provider,
|
||||
routing_model_name.clone(),
|
||||
routing_llm_provider.clone(),
|
||||
));
|
||||
|
||||
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
|
||||
llm_provider_url.clone() + CHAT_COMPLETIONS_PATH,
|
||||
PLANO_ORCHESTRATOR_MODEL_NAME.to_string(),
|
||||
));
|
||||
|
||||
|
||||
let model_aliases = Arc::new(arch_config.model_aliases.clone());
|
||||
|
||||
// Initialize trace collector and start background flusher
|
||||
|
|
@ -154,6 +161,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let io = TokioIo::new(stream);
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||
let orchestrator_service: Arc<OrchestratorService> = Arc::clone(&orchestrator_service);
|
||||
let model_aliases: Arc<
|
||||
Option<std::collections::HashMap<String, common::configuration::ModelAlias>>,
|
||||
> = Arc::clone(&model_aliases);
|
||||
|
|
@ -166,6 +174,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let state_storage = state_storage.clone();
|
||||
let service = service_fn(move |req| {
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let orchestrator_service = Arc::clone(&orchestrator_service);
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
let llm_providers = llm_providers.clone();
|
||||
|
|
@ -188,7 +197,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let fully_qualified_url = format!("{}{}", llm_provider_url, stripped_path);
|
||||
return agent_chat(
|
||||
req,
|
||||
router_service,
|
||||
orchestrator_service,
|
||||
fully_qualified_url,
|
||||
agents_list,
|
||||
listeners,
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ pub struct RouterService {
|
|||
router_url: String,
|
||||
client: reqwest::Client,
|
||||
router_model: Arc<dyn RouterModel>,
|
||||
#[allow(dead_code)]
|
||||
routing_provider_name: String,
|
||||
llm_usage_defined: bool,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
pub mod llm_router;
|
||||
pub mod orchestrator_model;
|
||||
pub mod orchestrator_model_v1;
|
||||
pub mod plano_orchestrator;
|
||||
pub mod router_model;
|
||||
pub mod router_model_v1;
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ use tracing::{debug, warn};
|
|||
|
||||
use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
|
||||
|
||||
pub const MAX_TOKEN_LEN: usize = 2048; // Default max token length for the orchestration model
|
||||
|
||||
/// Custom JSON formatter that produces spaced JSON (space after colons and commas), same as JSON in python
|
||||
struct SpacedJsonFormatter;
|
||||
|
||||
|
|
|
|||
162
crates/brightstaff/src/router/plano_orchestrator.rs
Normal file
162
crates/brightstaff/src/router/plano_orchestrator.rs
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use common::{
|
||||
configuration::{AgentUsagePreference, OrchestrationPreference},
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, PLANO_ORCHESTRATOR_MODEL_NAME},
|
||||
};
|
||||
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::orchestrator_model_v1::{self};
|
||||
|
||||
use super::orchestrator_model::OrchestratorModel;
|
||||
|
||||
pub struct OrchestratorService {
|
||||
orchestrator_url: String,
|
||||
client: reqwest::Client,
|
||||
orchestrator_model: Arc<dyn OrchestratorModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OrchestrationError {
|
||||
#[error("Failed to send request: {0}")]
|
||||
RequestError(#[from] reqwest::Error),
|
||||
|
||||
#[error("Failed to parse JSON: {0}, JSON: {1}")]
|
||||
JsonError(serde_json::Error, String),
|
||||
|
||||
#[error("Orchestrator model error: {0}")]
|
||||
OrchestratorModelError(#[from] super::orchestrator_model::OrchestratorModelError),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, OrchestrationError>;
|
||||
|
||||
impl OrchestratorService {
|
||||
pub fn new(
|
||||
orchestrator_url: String,
|
||||
orchestration_model_name: String,
|
||||
) -> Self {
|
||||
// Empty agent orchestrations - will be provided via usage_preferences in requests
|
||||
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
|
||||
|
||||
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model_name.clone(),
|
||||
orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
));
|
||||
|
||||
OrchestratorService {
|
||||
orchestrator_url,
|
||||
client: reqwest::Client::new(),
|
||||
orchestrator_model,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn determine_orchestration(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
trace_parent: Option<String>,
|
||||
usage_preferences: Option<Vec<AgentUsagePreference>>,
|
||||
) -> Result<Option<Vec<(String, String)>>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Require usage_preferences to be provided
|
||||
if usage_preferences.is_none() || usage_preferences.as_ref().unwrap().is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let orchestrator_request = self
|
||||
.orchestrator_model
|
||||
.generate_request(messages, &usage_preferences);
|
||||
|
||||
debug!(
|
||||
"sending request to arch-orchestrator model: {}, endpoint: {}",
|
||||
self.orchestrator_model.get_model_name(),
|
||||
self.orchestrator_url
|
||||
);
|
||||
|
||||
debug!(
|
||||
"arch orchestrator request body: {}",
|
||||
&serde_json::to_string(&orchestrator_request).unwrap(),
|
||||
);
|
||||
|
||||
let mut orchestration_request_headers = header::HeaderMap::new();
|
||||
orchestration_request_headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
orchestration_request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||
header::HeaderValue::from_str(PLANO_ORCHESTRATOR_MODEL_NAME).unwrap(),
|
||||
);
|
||||
|
||||
if let Some(trace_parent) = trace_parent {
|
||||
orchestration_request_headers.insert(
|
||||
header::HeaderName::from_static("traceparent"),
|
||||
header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
orchestration_request_headers.insert(
|
||||
header::HeaderName::from_static("model"),
|
||||
header::HeaderValue::from_static(PLANO_ORCHESTRATOR_MODEL_NAME),
|
||||
);
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let res = self
|
||||
.client
|
||||
.post(&self.orchestrator_url)
|
||||
.headers(orchestration_request_headers)
|
||||
.body(serde_json::to_string(&orchestrator_request).unwrap())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let body = res.text().await?;
|
||||
let orchestrator_response_time = start_time.elapsed();
|
||||
|
||||
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Failed to parse JSON: {}. Body: {}",
|
||||
err,
|
||||
&serde_json::to_string(&body).unwrap()
|
||||
);
|
||||
return Err(OrchestrationError::JsonError(
|
||||
err,
|
||||
format!("Failed to parse JSON: {}", body),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if chat_completion_response.choices.is_empty() {
|
||||
warn!("No choices in orchestrator response: {}", body);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if let Some(content) = &chat_completion_response.choices[0].message.content {
|
||||
let parsed_response = self
|
||||
.orchestrator_model
|
||||
.parse_response(content, &usage_preferences)?;
|
||||
info!(
|
||||
"arch-orchestrator determined routes: {}, selected_routes: {:?}, response time: {}ms",
|
||||
content.replace("\n", "\\n"),
|
||||
parsed_response,
|
||||
orchestrator_response_time.as_millis()
|
||||
);
|
||||
|
||||
if let Some(ref parsed_response) = parsed_response {
|
||||
return Ok(Some(parsed_response.clone()));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -150,9 +150,12 @@ pub mod operation_component {
|
|||
/// Inbound request handling
|
||||
pub const INBOUND: &str = "plano(inbound)";
|
||||
|
||||
/// Routing decision phase
|
||||
/// Orchestrator for llm route selection
|
||||
pub const ROUTING: &str = "plano(routing)";
|
||||
|
||||
/// Orchestrator for agent selection
|
||||
pub const ORCHESTRATOR: &str = "plano(orchestrator)";
|
||||
|
||||
/// Handoff to upstream service
|
||||
pub const HANDOFF: &str = "plano(handoff)";
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue