plano orchestration (draft)

This commit is contained in:
Adil Hafeez 2025-12-18 11:21:19 -08:00
parent 2f9121407b
commit 53e03901d2
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
25 changed files with 2520 additions and 285 deletions

View file

@ -17,7 +17,7 @@ use tracing::{debug, info, warn};
use super::agent_selector::{AgentSelectionError, AgentSelector};
use super::pipeline_processor::{PipelineError, PipelineProcessor};
use super::response_handler::ResponseHandler;
use crate::router::llm_router::RouterService;
use crate::router::plano_orchestrator::OrchestratorService;
use crate::tracing::{OperationNameBuilder, operation_component, http};
/// Main errors for agent chat completions
@ -37,7 +37,7 @@ pub enum AgentFilterChainError {
pub async fn agent_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
orchestrator_service: Arc<OrchestratorService>,
_: String,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
@ -45,7 +45,7 @@ pub async fn agent_chat(
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
match handle_agent_chat(
request,
router_service,
orchestrator_service,
agents_list,
listeners,
trace_collector,
@ -123,13 +123,13 @@ pub async fn agent_chat(
async fn handle_agent_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
orchestrator_service: Arc<OrchestratorService>,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
trace_collector: Arc<common::traces::TraceCollector>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
// Initialize services
let agent_selector = AgentSelector::new(router_service);
let agent_selector = AgentSelector::new(orchestrator_service);
let mut pipeline_processor = PipelineProcessor::default();
let response_handler = ResponseHandler::new();
@ -215,94 +215,123 @@ async fn handle_agent_chat(
(String::new(), None)
};
// Select appropriate agent using arch router llm model
let selected_agent = agent_selector
.select_agent(&message, &listener, trace_parent.clone())
// Select appropriate agents using arch orchestrator llm model
let selected_agents = agent_selector
.select_agents(&message, &listener, trace_parent.clone())
.await?;
debug!("Processing agent pipeline: {}", selected_agent.id);
info!("Selected {} agent(s) for execution", selected_agents.len());
// Record the start time for agent span
let agent_start_time = SystemTime::now();
let agent_start_instant = Instant::now();
// let (span_id, trace_id) = trace_collector.start_span(
// trace_parent.clone(),
// operation_component::AGENT,
// &format!("/agents{}", request_path),
// &selected_agent.id,
// );
// Execute agents sequentially, passing output from one to the next
let mut current_messages = message.clone();
let agent_count = selected_agents.len();
let span_id = generate_random_span_id();
for (agent_index, selected_agent) in selected_agents.iter().enumerate() {
let is_last_agent = agent_index == agent_count - 1;
debug!(
"Processing agent {}/{}: {}",
agent_index + 1,
agent_count,
selected_agent.id
);
// Process the filter chain
let chat_history = pipeline_processor
.process_filter_chain(
&message,
&selected_agent,
&agent_map,
&request_headers,
Some(&trace_collector),
trace_id.clone(),
span_id.clone(),
)
.await?;
// Record the start time for agent span
let agent_start_time = SystemTime::now();
let agent_start_instant = Instant::now();
let span_id = generate_random_span_id();
// Get terminal agent and send final response
let terminal_agent_name = selected_agent.id.clone();
let terminal_agent = agent_map.get(&terminal_agent_name).unwrap();
// Process the filter chain
let chat_history = pipeline_processor
.process_filter_chain(
&current_messages,
selected_agent,
&agent_map,
&request_headers,
Some(&trace_collector),
trace_id.clone(),
span_id.clone(),
)
.await?;
debug!("Processing terminal agent: {}", terminal_agent_name);
debug!("Terminal agent details: {:?}", terminal_agent);
// Get agent details and invoke
let agent_name = selected_agent.id.clone();
let agent = agent_map.get(&agent_name).unwrap();
let llm_response = pipeline_processor
.invoke_agent(
&chat_history,
client_request,
terminal_agent,
&request_headers,
trace_id.clone(),
span_id.clone(),
)
.await?;
debug!("Invoking agent: {}", agent_name);
// Record agent span after processing is complete
let agent_end_time = SystemTime::now();
let agent_elapsed = agent_start_instant.elapsed();
let llm_response = pipeline_processor
.invoke_agent(
&chat_history,
client_request.clone(),
agent,
&request_headers,
trace_id.clone(),
span_id.clone(),
)
.await?;
// Build full path with /agents prefix
let full_path = format!("/agents{}", request_path);
// Record agent span
let agent_end_time = SystemTime::now();
let agent_elapsed = agent_start_instant.elapsed();
let full_path = format!("/agents{}", request_path);
let operation_name = OperationNameBuilder::new()
.with_method("POST")
.with_path(&full_path)
.with_target(&agent_name)
.build();
// Build operation name: POST {full_path} {agent_name}
let operation_name = OperationNameBuilder::new()
.with_method("POST")
.with_path(&full_path)
.with_target(&terminal_agent_name)
.build();
let mut span_builder = SpanBuilder::new(&operation_name)
.with_span_id(span_id)
.with_kind(SpanKind::Internal)
.with_start_time(agent_start_time)
.with_end_time(agent_end_time)
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, full_path)
.with_attribute("agent.name", agent_name.clone())
.with_attribute("agent.sequence", format!("{}/{}", agent_index + 1, agent_count))
.with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0));
let mut span_builder = SpanBuilder::new(&operation_name)
.with_span_id(span_id)
.with_kind(SpanKind::Internal)
.with_start_time(agent_start_time)
.with_end_time(agent_end_time)
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, full_path)
.with_attribute("agent.name", terminal_agent_name.clone())
.with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0));
if !trace_id.is_empty() {
span_builder = span_builder.with_trace_id(trace_id.clone());
}
if let Some(parent_id) = parent_span_id.clone() {
span_builder = span_builder.with_parent_span_id(parent_id);
}
if !trace_id.is_empty() {
span_builder = span_builder.with_trace_id(trace_id);
}
if let Some(parent_id) = parent_span_id {
span_builder = span_builder.with_parent_span_id(parent_id);
let span = span_builder.build();
trace_collector.record_span(operation_component::AGENT, span);
// If this is the last agent, return the streaming response
if is_last_agent {
info!("Completed agent chain, returning response from last agent: {}", agent_name);
return response_handler
.create_streaming_response(llm_response)
.await
.map_err(AgentFilterChainError::from);
}
// For intermediate agents, collect the full response and pass to next agent
debug!("Collecting response from intermediate agent: {}", agent_name);
let response_text = response_handler.collect_full_response(llm_response).await?;
// Create a new message with the agent's response as assistant message
// and add it to the conversation history
current_messages.push(OpenAIMessage {
role: hermesllm::apis::openai::Role::Assistant,
content: hermesllm::apis::openai::MessageContent::Text(response_text.clone()),
name: Some(agent_name.clone()),
tool_calls: None,
tool_call_id: None,
});
info!(
"Agent {} completed, passing {} character response to next agent",
agent_name,
response_text.len()
);
}
let span = span_builder.build();
// Use plano(agent) as service name for the agent processing span
trace_collector.record_span(operation_component::AGENT, span);
// Create streaming response
response_handler
.create_streaming_response(llm_response)
.await
.map_err(AgentFilterChainError::from)
// This should never be reached since we return in the last agent iteration
unreachable!("Agent execution loop should have returned a response")
}

View file

@ -2,12 +2,12 @@ use std::collections::HashMap;
use std::sync::Arc;
use common::configuration::{
Agent, AgentFilterChain, Listener, ModelUsagePreference, RoutingPreference,
Agent, AgentFilterChain, Listener, AgentUsagePreference, OrchestrationPreference,
};
use hermesllm::apis::openai::Message;
use tracing::{debug, warn};
use crate::router::llm_router::RouterService;
use crate::router::plano_orchestrator::OrchestratorService;
/// Errors that can occur during agent selection
#[derive(Debug, thiserror::Error)]
@ -16,23 +16,23 @@ pub enum AgentSelectionError {
ListenerNotFound(String),
#[error("No agents configured for listener: {0}")]
NoAgentsConfigured(String),
#[error("Routing service error: {0}")]
RoutingError(String),
#[error("Default agent not found for listener: {0}")]
DefaultAgentNotFound(String),
#[error("MCP client error: {0}")]
McpError(String),
#[error("Orchestration service error: {0}")]
OrchestrationError(String),
}
/// Service for selecting agents based on routing preferences and listener configuration
/// Service for selecting agents based on orchestration preferences and listener configuration
pub struct AgentSelector {
router_service: Arc<RouterService>,
orchestrator_service: Arc<OrchestratorService>,
}
impl AgentSelector {
pub fn new(router_service: Arc<RouterService>) -> Self {
pub fn new(orchestrator_service: Arc<OrchestratorService>) -> Self {
Self {
router_service,
orchestrator_service,
}
}
@ -63,59 +63,6 @@ impl AgentSelector {
.collect()
}
/// Select appropriate agent based on routing preferences
pub async fn select_agent(
&self,
messages: &[Message],
listener: &Listener,
trace_parent: Option<String>,
) -> Result<AgentFilterChain, AgentSelectionError> {
let agents = listener
.agents
.as_ref()
.ok_or_else(|| AgentSelectionError::NoAgentsConfigured(listener.name.clone()))?;
// If only one agent, skip routing
if agents.len() == 1 {
debug!("Only one agent available, skipping routing");
return Ok(agents[0].clone());
}
let usage_preferences = self
.convert_agent_description_to_routing_preferences(agents)
.await;
debug!(
"Agents usage preferences for agent routing str: {}",
serde_json::to_string(&usage_preferences).unwrap_or_default()
);
match self
.router_service
.determine_route(messages, trace_parent, Some(usage_preferences))
.await
{
Ok(Some((_, agent_name))) => {
debug!("Determined agent: {}", agent_name);
let selected_agent = agents
.iter()
.find(|a| a.id == agent_name)
.cloned()
.ok_or_else(|| {
AgentSelectionError::RoutingError(format!(
"Selected agent '{}' not found in listener agents",
agent_name
))
})?;
Ok(selected_agent)
}
Ok(None) => {
debug!("No agent determined using routing preferences, using default agent");
self.get_default_agent(agents, &listener.name)
}
Err(err) => Err(AgentSelectionError::RoutingError(err.to_string())),
}
}
/// Get the default agent or the first agent if no default is specified
fn get_default_agent(
&self,
@ -136,17 +83,17 @@ impl AgentSelector {
.ok_or_else(|| AgentSelectionError::DefaultAgentNotFound(listener_name.to_string()))
}
/// Convert agent descriptions to routing preferences
async fn convert_agent_description_to_routing_preferences(
/// Convert agent descriptions to orchestration preferences
async fn convert_agent_description_to_orchestration_preferences(
&self,
agents: &[AgentFilterChain],
) -> Vec<ModelUsagePreference> {
) -> Vec<AgentUsagePreference> {
let mut preferences = Vec::new();
for agent_chain in agents {
preferences.push(ModelUsagePreference {
preferences.push(AgentUsagePreference {
model: agent_chain.id.clone(),
routing_preferences: vec![RoutingPreference {
orchestration_preferences: vec![OrchestrationPreference {
name: agent_chain.id.clone(),
description: agent_chain.description.clone().unwrap_or_default(),
}],
@ -155,6 +102,71 @@ impl AgentSelector {
preferences
}
/// Select multiple agents using orchestration
pub async fn select_agents(
&self,
messages: &[Message],
listener: &Listener,
trace_parent: Option<String>,
) -> Result<Vec<AgentFilterChain>, AgentSelectionError> {
let agents = listener
.agents
.as_ref()
.ok_or_else(|| AgentSelectionError::NoAgentsConfigured(listener.name.clone()))?;
// If only one agent, skip orchestration
if agents.len() == 1 {
debug!("Only one agent available, skipping orchestration");
return Ok(vec![agents[0].clone()]);
}
let usage_preferences = self
.convert_agent_description_to_orchestration_preferences(agents)
.await;
debug!(
"Agents usage preferences for orchestration: {}",
serde_json::to_string(&usage_preferences).unwrap_or_default()
);
match self
.orchestrator_service
.determine_orchestration(messages, trace_parent, Some(usage_preferences))
.await
{
Ok(Some(routes)) => {
debug!("Determined {} agent(s) via orchestration", routes.len());
let mut selected_agents = Vec::new();
for (route_name, agent_name) in routes {
debug!("Processing route: {}, agent: {}", route_name, agent_name);
let selected_agent = agents
.iter()
.find(|a| a.id == agent_name)
.cloned()
.ok_or_else(|| {
AgentSelectionError::OrchestrationError(format!(
"Selected agent '{}' not found in listener agents",
agent_name
))
})?;
selected_agents.push(selected_agent);
}
if selected_agents.is_empty() {
debug!("No agents determined using orchestration, using default agent");
Ok(vec![self.get_default_agent(agents, &listener.name)?])
} else {
Ok(selected_agents)
}
}
Ok(None) => {
debug!("No agents determined using orchestration, using default agent");
Ok(vec![self.get_default_agent(agents, &listener.name)?])
}
Err(err) => Err(AgentSelectionError::OrchestrationError(err.to_string())),
}
}
}
#[cfg(test)]
@ -162,8 +174,8 @@ mod tests {
use super::*;
use common::configuration::{AgentFilterChain, Listener};
fn create_test_router_service() -> Arc<RouterService> {
Arc::new(RouterService::new(
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
Arc::new(OrchestratorService::new(
vec![], // empty providers for testing
"http://localhost:8080".to_string(),
"test-model".to_string(),
@ -201,8 +213,8 @@ mod tests {
#[tokio::test]
async fn test_find_listener_success() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let orchestrator_service = create_test_orchestrator_service();
let selector = AgentSelector::new(orchestrator_service);
let listener1 = create_test_listener("test-listener", vec![]);
let listener2 = create_test_listener("other-listener", vec![]);
@ -218,8 +230,8 @@ mod tests {
#[tokio::test]
async fn test_find_listener_not_found() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let orchestrator_service = create_test_orchestrator_service();
let selector = AgentSelector::new(orchestrator_service);
let listeners = vec![create_test_listener("other-listener", vec![])];
@ -236,8 +248,8 @@ mod tests {
#[test]
fn test_create_agent_map() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let orchestrator_service = create_test_orchestrator_service();
let selector = AgentSelector::new(orchestrator_service);
let agents = vec![
create_test_agent_struct("agent1"),
@ -251,33 +263,10 @@ mod tests {
assert!(agent_map.contains_key("agent2"));
}
#[tokio::test]
async fn test_convert_agent_description_to_routing_preferences() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let agents = vec![
create_test_agent("agent1", "First agent description", true),
create_test_agent("agent2", "Second agent description", false),
];
let preferences = selector
.convert_agent_description_to_routing_preferences(&agents)
.await;
assert_eq!(preferences.len(), 2);
assert_eq!(preferences[0].model, "agent1");
assert_eq!(preferences[0].routing_preferences[0].name, "agent1");
assert_eq!(
preferences[0].routing_preferences[0].description,
"First agent description"
);
}
#[test]
fn test_get_default_agent() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let orchestrator_service = create_test_orchestrator_service();
let selector = AgentSelector::new(orchestrator_service);
let agents = vec![
create_test_agent("agent1", "First agent", false),
@ -293,8 +282,8 @@ mod tests {
#[test]
fn test_get_default_agent_fallback_to_first() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let orchestrator_service = create_test_orchestrator_service();
let selector = AgentSelector::new(orchestrator_service);
let agents = vec![
create_test_agent("agent1", "First agent", false),

View file

@ -6,11 +6,11 @@ use hyper::header::HeaderMap;
use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector};
use crate::handlers::pipeline_processor::PipelineProcessor;
use crate::handlers::response_handler::ResponseHandler;
use crate::router::llm_router::RouterService;
use crate::router::plano_orchestrator::OrchestratorService;
/// Integration test that demonstrates the modular agent chat flow
/// This test shows how the three main components work together:
/// 1. AgentSelector - selects the appropriate agent based on routing
/// 1. AgentSelector - selects the appropriate agents based on orchestration
/// 2. PipelineProcessor - executes the agent pipeline
/// 3. ResponseHandler - handles response streaming
#[cfg(test)]
@ -18,8 +18,8 @@ mod integration_tests {
use super::*;
use common::configuration::{Agent, AgentFilterChain, Listener};
fn create_test_router_service() -> Arc<RouterService> {
Arc::new(RouterService::new(
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
Arc::new(OrchestratorService::new(
vec![], // empty providers for testing
"http://localhost:8080".to_string(),
"test-model".to_string(),
@ -40,8 +40,8 @@ mod integration_tests {
#[tokio::test]
async fn test_modular_agent_chat_flow() {
// Setup services
let router_service = create_test_router_service();
let agent_selector = AgentSelector::new(router_service);
let orchestrator_service = create_test_orchestrator_service();
let agent_selector = AgentSelector::new(orchestrator_service);
let mut pipeline_processor = PipelineProcessor::default();
// Create test data

View file

@ -200,7 +200,13 @@ impl PipelineProcessor {
) -> Result<Vec<Message>, PipelineError> {
let mut chat_history_updated = chat_history.to_vec();
for agent_name in &agent_filter_chain.filter_chain {
// If filter_chain is None or empty, proceed without filtering
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
Some(fc) if !fc.is_empty() => fc,
_ => return Ok(chat_history_updated),
};
for agent_name in filter_chain {
debug!("Processing filter agent: {}", agent_name);
let agent = agent_map

View file

@ -113,6 +113,52 @@ impl ResponseHandler {
.body(stream_body)
.map_err(ResponseError::from)
}
/// Collect the full response body as a string
/// This is used for intermediate agents where we need to capture the full response
/// before passing it to the next agent.
///
/// This method handles both streaming and non-streaming responses:
/// - For streaming SSE responses: parses chunks and extracts text deltas
/// - For non-streaming responses: returns the full text
pub async fn collect_full_response(
&self,
llm_response: reqwest::Response,
) -> Result<String, ResponseError> {
use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
let response_bytes = llm_response
.bytes()
.await
.map_err(|e| ResponseError::StreamError(format!("Failed to read response: {}", e)))?;
// Try to parse as SSE streaming response
if let Ok(sse_iter) = SseStreamIter::try_from(response_bytes.as_ref()) {
let mut accumulated_text = String::new();
for sse_event in sse_iter {
// Skip [DONE] markers and event-only lines
if sse_event.is_done() || sse_event.is_event_only() {
continue;
}
// Try to get provider response and extract content delta
if let Ok(provider_response) = sse_event.provider_response() {
if let Some(content) = provider_response.content_delta() {
accumulated_text.push_str(&content);
}
}
}
return Ok(accumulated_text);
}
// If not SSE, treat as regular text response
let response_text = String::from_utf8(response_bytes.to_vec())
.map_err(|e| ResponseError::StreamError(format!("Failed to decode response: {}", e)))?;
Ok(response_text)
}
}
impl Default for ResponseHandler {

View file

@ -3,6 +3,7 @@ use brightstaff::handlers::function_calling::function_calling_chat_handler;
use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::router::llm_router::RouterService;
use brightstaff::router::plano_orchestrator::OrchestratorService;
use brightstaff::state::StateStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::memory::MemoryConversationalStorage;
@ -95,10 +96,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
arch_config.model_providers.clone(),
llm_provider_url.clone() + CHAT_COMPLETIONS_PATH,
routing_model_name,
routing_model_name.clone(),
routing_llm_provider.clone(),
));
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
arch_config.model_providers.clone(),
llm_provider_url.clone() + CHAT_COMPLETIONS_PATH,
"Plano-Orchestrator".to_string(),
routing_llm_provider,
));
let model_aliases = Arc::new(arch_config.model_aliases.clone());
// Initialize trace collector and start background flusher
@ -154,6 +163,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let io = TokioIo::new(stream);
let router_service: Arc<RouterService> = Arc::clone(&router_service);
let orchestrator_service: Arc<OrchestratorService> = Arc::clone(&orchestrator_service);
let model_aliases: Arc<
Option<std::collections::HashMap<String, common::configuration::ModelAlias>>,
> = Arc::clone(&model_aliases);
@ -166,6 +176,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let state_storage = state_storage.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
let orchestrator_service = Arc::clone(&orchestrator_service);
let parent_cx = extract_context_from_request(&req);
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
@ -188,7 +199,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let fully_qualified_url = format!("{}{}", llm_provider_url, stripped_path);
return agent_chat(
req,
router_service,
orchestrator_service,
fully_qualified_url,
agents_list,
listeners,

View file

@ -1,5 +1,6 @@
pub mod llm_router;
pub mod orchestrator_model;
pub mod orchestrator_model_v1;
pub mod plano_orchestrator;
pub mod router_model;
pub mod router_model_v1;
pub mod router_model_v1;

View file

@ -7,6 +7,8 @@ use tracing::{debug, warn};
use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
pub const MAX_TOKEN_LEN: usize = 2048; // Default max token length for the orchestration model
/// Custom JSON formatter that produces spaced JSON (space after colons and commas), same as JSON in python
struct SpacedJsonFormatter;

View file

@ -0,0 +1,166 @@
use std::{collections::HashMap, sync::Arc};
use common::{
configuration::{LlmProvider, AgentUsagePreference, OrchestrationPreference},
consts::ARCH_PROVIDER_HINT_HEADER,
};
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
use hyper::header;
use thiserror::Error;
use tracing::{debug, info, warn};
use crate::router::orchestrator_model_v1::{self};
use super::orchestrator_model::OrchestratorModel;
pub struct OrchestratorService {
orchestrator_url: String,
client: reqwest::Client,
orchestrator_model: Arc<dyn OrchestratorModel>,
orchestration_provider_name: String,
}
#[derive(Debug, Error)]
pub enum OrchestrationError {
#[error("Failed to send request: {0}")]
RequestError(#[from] reqwest::Error),
#[error("Failed to parse JSON: {0}, JSON: {1}")]
JsonError(serde_json::Error, String),
#[error("Orchestrator model error: {0}")]
OrchestratorModelError(#[from] super::orchestrator_model::OrchestratorModelError),
}
pub type Result<T> = std::result::Result<T, OrchestrationError>;
impl OrchestratorService {
pub fn new(
_providers: Vec<LlmProvider>,
orchestrator_url: String,
orchestration_model_name: String,
orchestration_provider_name: String,
) -> Self {
// Empty agent orchestrations - will be provided via usage_preferences in requests
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model_name.clone(),
orchestrator_model_v1::MAX_TOKEN_LEN,
));
OrchestratorService {
orchestrator_url,
client: reqwest::Client::new(),
orchestrator_model,
orchestration_provider_name,
}
}
pub async fn determine_orchestration(
&self,
messages: &[Message],
trace_parent: Option<String>,
usage_preferences: Option<Vec<AgentUsagePreference>>,
) -> Result<Option<Vec<(String, String)>>> {
if messages.is_empty() {
return Ok(None);
}
// Require usage_preferences to be provided
if usage_preferences.is_none() || usage_preferences.as_ref().unwrap().is_empty() {
return Ok(None);
}
let orchestrator_request = self
.orchestrator_model
.generate_request(messages, &usage_preferences);
debug!(
"sending request to arch-orchestrator model: {}, endpoint: {}",
self.orchestrator_model.get_model_name(),
self.orchestrator_url
);
debug!(
"arch orchestrator request body: {}",
&serde_json::to_string(&orchestrator_request).unwrap(),
);
let mut orchestration_request_headers = header::HeaderMap::new();
orchestration_request_headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
orchestration_request_headers.insert(
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
header::HeaderValue::from_str(&self.orchestration_provider_name).unwrap(),
);
if let Some(trace_parent) = trace_parent {
orchestration_request_headers.insert(
header::HeaderName::from_static("traceparent"),
header::HeaderValue::from_str(&trace_parent).unwrap(),
);
}
orchestration_request_headers.insert(
header::HeaderName::from_static("model"),
header::HeaderValue::from_static("Plano-Orchestrator"),
);
let start_time = std::time::Instant::now();
let res = self
.client
.post(&self.orchestrator_url)
.headers(orchestration_request_headers)
.body(serde_json::to_string(&orchestrator_request).unwrap())
.send()
.await?;
let body = res.text().await?;
let orchestrator_response_time = start_time.elapsed();
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
Ok(response) => response,
Err(err) => {
warn!(
"Failed to parse JSON: {}. Body: {}",
err,
&serde_json::to_string(&body).unwrap()
);
return Err(OrchestrationError::JsonError(
err,
format!("Failed to parse JSON: {}", body),
));
}
};
if chat_completion_response.choices.is_empty() {
warn!("No choices in orchestrator response: {}", body);
return Ok(None);
}
if let Some(content) = &chat_completion_response.choices[0].message.content {
let parsed_response = self
.orchestrator_model
.parse_response(content, &usage_preferences)?;
info!(
"arch-orchestrator determined routes: {}, selected_routes: {:?}, response time: {}ms",
content.replace("\n", "\\n"),
parsed_response,
orchestrator_response_time.as_millis()
);
if let Some(ref parsed_response) = parsed_response {
return Ok(Some(parsed_response.clone()));
}
Ok(None)
} else {
Ok(None)
}
}
}

View file

@ -33,7 +33,7 @@ pub struct AgentFilterChain {
pub id: String,
pub default: Option<bool>,
pub description: Option<String>,
pub filter_chain: Vec<String>,
pub filter_chain: Option<Vec<String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]