plano orchestration using plano orchestration 4b model (#637)

This commit is contained in:
Adil Hafeez 2025-12-22 18:05:49 -08:00 committed by GitHub
parent 60162e0575
commit 15fbb6c3af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
40 changed files with 4054 additions and 449 deletions

View file

@ -4,7 +4,7 @@ use common::configuration::{Agent, AgentFilterChain};
use common::consts::{
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
};
use common::traces::{SpanBuilder, SpanKind, generate_random_span_id};
use common::traces::{generate_random_span_id, SpanBuilder, SpanKind};
use hermesllm::apis::openai::Message;
use hermesllm::{ProviderRequest, ProviderRequestType};
use hyper::header::HeaderMap;
@ -200,7 +200,13 @@ impl PipelineProcessor {
) -> Result<Vec<Message>, PipelineError> {
let mut chat_history_updated = chat_history.to_vec();
for agent_name in &agent_filter_chain.filter_chain {
// If filter_chain is None or empty, proceed without filtering
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
Some(fc) if !fc.is_empty() => fc,
_ => return Ok(chat_history_updated),
};
for agent_name in filter_chain {
debug!("Processing filter agent: {}", agent_name);
let agent = agent_map
@ -210,10 +216,11 @@ impl PipelineProcessor {
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
info!(
"executing filter: {}/{}, url: {}, conversation length: {}",
"executing filter: {}/{}, url: {}, type: {}, conversation length: {}",
agent_name,
tool_name,
agent.url,
agent.agent_type.as_deref().unwrap_or("mcp"),
chat_history.len()
);
@ -223,16 +230,29 @@ impl PipelineProcessor {
// Generate filter span ID before execution so MCP spans can use it as parent
let filter_span_id = generate_random_span_id();
chat_history_updated = self
.execute_filter(
&chat_history_updated,
agent,
request_headers,
trace_collector,
trace_id.clone(),
filter_span_id.clone(),
)
.await?;
if agent.agent_type.as_deref().unwrap_or("mcp") == "mcp" {
chat_history_updated = self
.execute_mcp_filter(
&chat_history_updated,
agent,
request_headers,
trace_collector,
trace_id.clone(),
filter_span_id.clone(),
)
.await?;
} else {
chat_history_updated = self
.execute_rest_filter(
&chat_history_updated,
agent,
request_headers,
trace_collector,
trace_id.clone(),
filter_span_id.clone(),
)
.await?;
}
let end_time = SystemTime::now();
let elapsed = start_instant.elapsed();
@ -406,7 +426,7 @@ impl PipelineProcessor {
}
/// Send request to a specific agent and return the response content
async fn execute_filter(
async fn execute_mcp_filter(
&mut self,
messages: &[Message],
agent: &Agent,
@ -420,11 +440,7 @@ impl PipelineProcessor {
session_id.clone()
} else {
let session_id = self
.get_new_session_id(
&agent.id,
trace_id.clone(),
filter_span_id.clone(),
)
.get_new_session_id(&agent.id, trace_id.clone(), filter_span_id.clone())
.await;
self.agent_id_session_map
.insert(agent.id.clone(), session_id.clone());
@ -444,19 +460,20 @@ impl PipelineProcessor {
let mcp_span_id = generate_random_span_id();
// Build headers
let agent_headers =
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id), trace_id.clone(), mcp_span_id.clone())?;
let agent_headers = self.build_mcp_headers(
request_headers,
&agent.id,
Some(&mcp_session_id),
trace_id.clone(),
mcp_span_id.clone(),
)?;
// Send request with tracing
let start_time = SystemTime::now();
let start_instant = Instant::now();
let response = self
.send_mcp_request(
&json_rpc_request,
agent_headers,
&agent.id,
)
.send_mcp_request(&json_rpc_request, agent_headers, &agent.id)
.await?;
let http_status = response.status();
let response_bytes = response.bytes().await?;
@ -598,7 +615,13 @@ impl PipelineProcessor {
let notification_body = serde_json::to_string(&initialized_notification)?;
debug!("Sending initialized notification for agent {}", agent_id);
let headers = self.build_mcp_headers(&HeaderMap::new(), agent_id, Some(session_id), trace_id.clone(), parent_span_id.clone())?;
let headers = self.build_mcp_headers(
&HeaderMap::new(),
agent_id,
Some(session_id),
trace_id.clone(),
parent_span_id.clone(),
)?;
let response = self
.client
@ -626,7 +649,13 @@ impl PipelineProcessor {
let initialize_request = self.build_initialize_request();
let headers = self
.build_mcp_headers(&HeaderMap::new(), agent_id, None, trace_id.clone(), parent_span_id.clone())
.build_mcp_headers(
&HeaderMap::new(),
agent_id,
None,
trace_id.clone(),
parent_span_id.clone(),
)
.expect("Failed to build headers for initialization");
let response = self
@ -661,6 +690,129 @@ impl PipelineProcessor {
session_id
}
/// Execute a REST-based filter agent.
///
/// Posts `messages` to `agent.url` as a JSON array and deserializes the response body
/// into the message list that replaces the conversation for the next filter.
///
/// The outgoing headers start as a copy of `request_headers`, with the content length
/// dropped and the trace parent, upstream host, retry, `Accept`, and `Content-Type`
/// headers overridden. When `trace_collector` is present, a `rest_call` span parented to
/// `filter_span_id` is recorded for the HTTP round trip.
///
/// # Errors
///
/// - `PipelineError::AgentNotFound` if `agent.id` is not a valid header value.
/// - `PipelineError::ClientError` for a 4xx response, `PipelineError::ServerError` for
///   any other non-success status; both carry the (lossily decoded) response body.
/// - `PipelineError::ParseError` if the body is not a JSON array of messages.
/// - Transport failures from sending the request or reading the body are propagated
///   with `?` (through whatever conversion into `PipelineError` the crate defines).
async fn execute_rest_filter(
    &mut self,
    messages: &[Message],
    agent: &Agent,
    request_headers: &HeaderMap,
    trace_collector: Option<&std::sync::Arc<common::traces::TraceCollector>>,
    trace_id: String,
    filter_span_id: String,
) -> Result<Vec<Message>, PipelineError> {
    let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);

    // Generate span ID for this REST call (child of filter span)
    let rest_span_id = generate_random_span_id();

    // Build headers
    let mut agent_headers = request_headers.clone();
    // The body is re-serialized below, so the inbound content length no longer applies.
    agent_headers.remove(hyper::header::CONTENT_LENGTH);
    // Always strip the inbound trace parent so a stale value is never forwarded, even
    // when our own value cannot be built.
    agent_headers.remove(TRACE_PARENT_HEADER);

    let trace_parent = format!("00-{}-{}-01", trace_id, rest_span_id);
    match hyper::header::HeaderValue::from_str(&trace_parent) {
        Ok(value) => {
            agent_headers.insert(TRACE_PARENT_HEADER, value);
        }
        // `trace_id` is supplied by the caller; a malformed value must not panic the
        // request, so the call proceeds without trace propagation instead.
        Err(_) => debug!(
            "Skipping invalid trace parent header for REST agent {}: {:?}",
            agent.id, trace_parent
        ),
    }

    agent_headers.insert(
        ARCH_UPSTREAM_HOST_HEADER,
        hyper::header::HeaderValue::from_str(&agent.id)
            .map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
    );
    agent_headers.insert(
        ENVOY_RETRY_HEADER,
        // Constant value: build it statically rather than parsing and unwrapping at runtime.
        hyper::header::HeaderValue::from_static("3"),
    );
    agent_headers.insert(
        hyper::header::ACCEPT,
        hyper::header::HeaderValue::from_static("application/json"),
    );
    agent_headers.insert(
        hyper::header::CONTENT_TYPE,
        hyper::header::HeaderValue::from_static("application/json"),
    );

    // Send request with tracing
    let start_time = SystemTime::now();
    let start_instant = Instant::now();

    debug!(
        "Sending REST request to agent {} at URL: {}",
        agent.id, agent.url
    );

    // Send messages array directly as request body
    let response = self
        .client
        .post(&agent.url)
        .headers(agent_headers)
        .json(&messages)
        .send()
        .await?;

    let http_status = response.status();
    let response_bytes = response.bytes().await?;

    let end_time = SystemTime::now();
    let elapsed = start_instant.elapsed();

    // Record REST call span. This happens before the status check so that failed calls
    // are traced as well.
    if let Some(collector) = trace_collector {
        let mut attrs = HashMap::new();
        attrs.insert("rest.tool_name", tool_name.to_string());
        attrs.insert("rest.url", agent.url.clone());
        attrs.insert("http.status_code", http_status.as_u16().to_string());

        self.record_mcp_span(
            collector,
            "rest_call",
            &agent.id,
            start_time,
            end_time,
            elapsed,
            Some(attrs),
            // Last use of these owned strings, so they are moved rather than cloned.
            trace_id,
            filter_span_id,
            Some(rest_span_id),
        );
    }

    // Handle HTTP errors
    if !http_status.is_success() {
        let agent_id = agent.id.clone();
        let status = http_status.as_u16();
        let body = String::from_utf8_lossy(&response_bytes).to_string();

        return Err(if http_status.is_client_error() {
            PipelineError::ClientError {
                agent: agent_id,
                status,
                body,
            }
        } else {
            PipelineError::ServerError {
                agent: agent_id,
                status,
                body,
            }
        });
    }

    info!(
        "Response from REST agent {}: {}",
        agent.id,
        String::from_utf8_lossy(&response_bytes)
    );

    // Parse response - expecting array of messages directly
    serde_json::from_slice(&response_bytes).map_err(PipelineError::ParseError)
}
/// Send request to terminal agent and return the raw response for streaming
pub async fn invoke_agent(
&self,
@ -734,7 +886,7 @@ mod tests {
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
AgentFilterChain {
id: "test-agent".to_string(),
filter_chain: agents.iter().map(|s| s.to_string()).collect(),
filter_chain: Some(agents.iter().map(|s| s.to_string()).collect()),
description: None,
default: None,
}
@ -751,7 +903,15 @@ mod tests {
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
let result = processor
.process_filter_chain(&messages, &pipeline, &agent_map, &request_headers, None, String::new(), String::new())
.process_filter_chain(
&messages,
&pipeline,
&agent_map,
&request_headers,
None,
String::new(),
String::new(),
)
.await;
assert!(result.is_err());
@ -785,7 +945,14 @@ mod tests {
let request_headers = HeaderMap::new();
let result = processor
.execute_filter(&messages, &agent, &request_headers, None, "trace-123".to_string(), "span-123".to_string())
.execute_mcp_filter(
&messages,
&agent,
&request_headers,
None,
"trace-123".to_string(),
"span-123".to_string(),
)
.await;
match result {
@ -824,7 +991,14 @@ mod tests {
let request_headers = HeaderMap::new();
let result = processor
.execute_filter(&messages, &agent, &request_headers, None, "trace-456".to_string(), "span-456".to_string())
.execute_mcp_filter(
&messages,
&agent,
&request_headers,
None,
"trace-456".to_string(),
"span-456".to_string(),
)
.await;
match result {
@ -876,7 +1050,14 @@ mod tests {
let request_headers = HeaderMap::new();
let result = processor
.execute_filter(&messages, &agent, &request_headers, None, "trace-789".to_string(), "span-789".to_string())
.execute_mcp_filter(
&messages,
&agent,
&request_headers,
None,
"trace-789".to_string(),
"span-789".to_string(),
)
.await;
match result {