more changes

This commit is contained in:
Adil Hafeez 2025-11-28 11:34:43 -08:00
parent dcfc85ca74
commit 4140c1cde4
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
15 changed files with 883 additions and 109 deletions

View file

@ -124,20 +124,25 @@ async fn handle_agent_chat(
.find(|(key, _)| key.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
// Select appropriate agent using arch router llm model
let selected_agent = agent_selector
.select_agent(&chat_completions_request.messages, &listener, trace_parent)
.await?;
debug!("Processing agent pipeline: {}", selected_agent.id);
// Create agent map for pipeline processing
// Create agent map for pipeline processing and agent selection
let agent_map = {
let agents = agents_list.read().await;
let agents = agents.as_ref().unwrap();
agent_selector.create_agent_map(agents)
};
// Select appropriate agent using arch router llm model
let selected_agent = agent_selector
.select_agent(
&chat_completions_request.messages,
&listener,
trace_parent,
&agent_map,
)
.await?;
debug!("Processing agent pipeline: {}", selected_agent.id);
// Process the filter chain
let processed_messages = pipeline_processor
.process_filter_chain(

View file

@ -8,6 +8,7 @@ use hermesllm::apis::openai::Message;
use tracing::{debug, warn};
use crate::router::llm_router::RouterService;
use crate::utils::mcp_client::McpClient;
/// Errors that can occur during agent selection
#[derive(Debug, thiserror::Error)]
@ -20,16 +21,22 @@ pub enum AgentSelectionError {
RoutingError(String),
#[error("Default agent not found for listener: {0}")]
DefaultAgentNotFound(String),
#[error("MCP client error: {0}")]
McpError(String),
}
/// Picks the agent that should handle a request, based on the listener's
/// configured agents and the routing model's preferences.
pub struct AgentSelector {
    /// Routing service shared with the rest of the application.
    router_service: Arc<RouterService>,
    /// Client used to fetch tool descriptions for agents whose URL starts
    /// with `mcp://`.
    mcp_client: McpClient,
}
impl AgentSelector {
pub fn new(router_service: Arc<RouterService>) -> Self {
Self { router_service }
Self {
router_service,
mcp_client: McpClient::new(),
}
}
/// Find listener by name from the request headers
@ -65,6 +72,7 @@ impl AgentSelector {
messages: &[Message],
listener: &Listener,
trace_parent: Option<String>,
agent_map: &HashMap<String, Agent>,
) -> Result<AgentFilterChain, AgentSelectionError> {
let agents = listener
.agents
@ -77,7 +85,9 @@ impl AgentSelector {
return Ok(agents[0].clone());
}
let usage_preferences = self.convert_agent_description_to_routing_preferences(agents);
let usage_preferences = self
.convert_agent_description_to_routing_preferences(agents, agent_map)
.await;
debug!(
"Agents usage preferences for agent routing str: {}",
serde_json::to_string(&usage_preferences).unwrap_or_default()
@ -131,20 +141,75 @@ impl AgentSelector {
}
/// Convert agent descriptions to routing preferences
fn convert_agent_description_to_routing_preferences(
/// For agents with MCP URLs, fetches the tool description from the MCP server
async fn convert_agent_description_to_routing_preferences(
&self,
agents: &[AgentFilterChain],
agent_map: &HashMap<String, Agent>,
) -> Vec<ModelUsagePreference> {
agents
.iter()
.map(|agent| ModelUsagePreference {
model: agent.id.clone(),
let mut preferences = Vec::new();
for agent_chain in agents {
// Get the actual agent from the agent_map
let agent = agent_map.get(&agent_chain.id);
// Determine the description to use
let description = if let Some(agent) = agent {
// Check if this is an MCP agent (URL starts with mcp://)
if agent.url.starts_with("mcp://") {
debug!(
"Agent {} is an MCP agent, fetching tool description from: {}",
agent.id, agent.url
);
// Fetch description from MCP endpoint
match self
.mcp_client
.fetch_tool_description(&agent.url, agent.tool.as_deref())
.await
{
Ok(mcp_description) => {
if !mcp_description.is_empty() {
debug!(
"Fetched MCP description for agent {}: {}",
agent.id, mcp_description
);
mcp_description
} else {
warn!(
"MCP tool description is empty for agent {}, using config description",
agent.id
);
agent_chain.description.clone().unwrap_or_default()
}
}
Err(e) => {
warn!(
"Failed to fetch MCP description for agent {}: {}, using config description",
agent.id, e
);
agent_chain.description.clone().unwrap_or_default()
}
}
} else {
// Not an MCP agent, use description from config
agent_chain.description.clone().unwrap_or_default()
}
} else {
// Agent not found in map, use description from config
agent_chain.description.clone().unwrap_or_default()
};
preferences.push(ModelUsagePreference {
model: agent_chain.id.clone(),
routing_preferences: vec![RoutingPreference {
name: agent.id.clone(),
description: agent.description.as_ref().unwrap_or(&String::new()).clone(),
name: agent_chain.id.clone(),
description,
}],
})
.collect()
});
}
preferences
}
}
@ -185,6 +250,7 @@ mod tests {
id: name.to_string(),
kind: Some("test".to_string()),
url: "http://localhost:8080".to_string(),
tool: None,
}
}
@ -240,8 +306,8 @@ mod tests {
assert!(agent_map.contains_key("agent2"));
}
#[test]
fn test_convert_agent_description_to_routing_preferences() {
#[tokio::test]
async fn test_convert_agent_description_to_routing_preferences() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
@ -250,7 +316,15 @@ mod tests {
create_test_agent("agent2", "Second agent description", false),
];
let preferences = selector.convert_agent_description_to_routing_preferences(&agents);
let agent_structs = vec![
create_test_agent_struct("agent1"),
create_test_agent_struct("agent2"),
];
let agent_map = selector.create_agent_map(&agent_structs);
let preferences = selector
.convert_agent_description_to_routing_preferences(&agents, &agent_map)
.await;
assert_eq!(preferences.len(), 2);
assert_eq!(preferences[0].model, "agent1");

View file

@ -65,6 +65,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = Arc::new(RwLock::new(arch_config.model_providers.clone()));
let agents_list = Arc::new(RwLock::new(arch_config.agents.clone()));
let agent_filters = Arc::new(RwLock::new(arch_config.agent_filters.clone()));
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
debug!(
@ -111,6 +112,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone();
let agents_list = agents_list.clone();
let agent_filters = agent_filters.clone();
let listeners = listeners.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
@ -119,6 +121,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone();
let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone();
let agent_filters = agent_filters.clone();
let listeners = listeners.clone();
async move {

View file

@ -0,0 +1,235 @@
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, warn};
/// One tool entry as it appears in an MCP `tools/list` response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpTool {
    /// Tool name; lookups by tool name compare against this field.
    pub name: String,
    /// Free-text description of the tool, if the server provided one.
    pub description: Option<String>,
    /// Arbitrary JSON carried under the `inputSchema` key of the payload.
    #[serde(rename = "inputSchema")]
    pub input_schema: Option<serde_json::Value>,
}
/// Payload shape expected in each `data:` line of the `tools/list` stream:
/// a JSON object with a `tools` array.
#[derive(Debug, Serialize, Deserialize)]
struct McpToolsListResponse {
    /// Tools announced by the server in this payload.
    tools: Vec<McpTool>,
}
/// Failure modes reported by [`McpClient`].
#[derive(Debug, thiserror::Error)]
pub enum McpClientError {
    /// Sending the HTTP request or reading the response body failed.
    #[error("HTTP request failed: {0}")]
    HttpError(#[from] reqwest::Error),

    /// A `serde_json` error converted via `From`. The client methods in this
    /// file only log JSON failures and do not return this variant themselves.
    #[error("Failed to parse response: {0}")]
    ParseError(#[from] serde_json::Error),

    /// The MCP URL is malformed, or no tool name could be determined from it.
    #[error("Invalid MCP URL: {0}")]
    InvalidUrl(String),

    /// The server's tool list has no entry with the requested name.
    #[error("Tool not found: {0}")]
    ToolNotFound(String),
}
/// HTTP client for MCP (Model Context Protocol) servers addressed with
/// `mcp://` URLs.
pub struct McpClient {
    /// Underlying HTTP client used for every request this type sends.
    client: Client,
}
impl Default for McpClient {
fn default() -> Self {
Self::new()
}
}
impl McpClient {
    /// Creates a client backed by a default-configured `reqwest::Client`.
    ///
    /// NOTE(review): no request timeout is configured here. Confirm whether
    /// callers on a request path need the wait on a slow MCP server bounded.
    pub fn new() -> Self {
        Self {
            client: Client::new(),
        }
    }

    /// Splits an `mcp://` URL into an HTTP base URL and an optional tool name.
    ///
    /// Supported formats:
    /// - `mcp://host:port`
    /// - `mcp://host:port/tool_name`
    /// - `mcp://host:port?tool=tool_name`
    ///
    /// The path and the query string are parsed independently, so
    /// `mcp://host:port/tool_name?other=1` also yields `tool_name`. When both
    /// a path and a `tool=` query parameter name a tool, the query parameter
    /// wins. Empty names (`mcp://host:port/`, `?tool=`) are treated as "no
    /// tool given" rather than as a tool called `""`.
    ///
    /// Returns `(base_url, tool_name)` where `base_url` is `http://host:port`
    /// without a trailing slash.
    ///
    /// # Errors
    ///
    /// Returns [`McpClientError::InvalidUrl`] when the URL does not start with
    /// `mcp://` or has no host part.
    fn parse_mcp_url(&self, mcp_url: &str) -> Result<(String, Option<String>), McpClientError> {
        let rest = mcp_url.strip_prefix("mcp://").ok_or_else(|| {
            McpClientError::InvalidUrl(format!("URL must start with mcp://: {}", mcp_url))
        })?;

        // Cut the query string off first so that a path in front of it is
        // still recognised as a path instead of leaking into the base URL.
        let (location, query) = match rest.split_once('?') {
            Some((location, query)) => (location, Some(query)),
            None => (rest, None),
        };
        let (authority, path) = match location.split_once('/') {
            Some((authority, path)) => (authority, path),
            None => (location, ""),
        };

        if authority.is_empty() {
            return Err(McpClientError::InvalidUrl(format!(
                "URL is missing a host: {}",
                mcp_url
            )));
        }

        // The last non-empty `tool=` parameter wins.
        let query_tool = query.and_then(|query| {
            query
                .split('&')
                .filter_map(|param| param.split_once('='))
                .filter(|(key, value)| *key == "tool" && !value.is_empty())
                .map(|(_, value)| value)
                .last()
        });
        // Surrounding slashes are not part of the tool name.
        let path_tool = Some(path.trim_matches('/')).filter(|name| !name.is_empty());

        let tool_name = query_tool.or(path_tool).map(String::from);

        Ok((format!("http://{}", authority), tool_name))
    }

    /// Fetches the tool list from `<base_url>/sse/tools/list`.
    ///
    /// The response body is read in full and scanned line by line. Every
    /// `data:` line is deserialized as a `tools/list` payload and its tools
    /// are appended to the result. Scanning stops at a `[DONE]` payload.
    /// Payloads that fail to deserialize are skipped with a debug log.
    ///
    /// A non-success HTTP status is treated as best-effort: a warning is
    /// logged and an empty list is returned.
    ///
    /// # Errors
    ///
    /// Returns [`McpClientError::InvalidUrl`] if `mcp_url` cannot be parsed and
    /// [`McpClientError::HttpError`] if the request or reading the body fails.
    pub async fn fetch_tools(&self, mcp_url: &str) -> Result<Vec<McpTool>, McpClientError> {
        let (http_url, _) = self.parse_mcp_url(mcp_url)?;
        let tools_list_url = format!("{}/sse/tools/list", http_url);
        debug!("Fetching tools from MCP endpoint: {}", tools_list_url);

        let response = self
            .client
            .get(&tools_list_url)
            .header("Accept", "text/event-stream")
            .send()
            .await?;

        if !response.status().is_success() {
            warn!(
                "Failed to fetch tools from {}: status {}",
                tools_list_url,
                response.status()
            );
            return Ok(Vec::new());
        }

        let body = response.text().await?;
        debug!("Received tools list response: {}", body);

        let mut tools = Vec::new();
        for line in body.lines() {
            // The space after the colon is optional in an event stream, so
            // accept both `data: {...}` and `data:{...}`.
            let data = match line.strip_prefix("data:") {
                Some(data) => data.trim(),
                None => continue,
            };
            if data == "[DONE]" {
                break;
            }
            match serde_json::from_str::<McpToolsListResponse>(data) {
                Ok(response) => tools.extend(response.tools),
                Err(e) => {
                    debug!("Failed to parse tools list data: {}, line: {}", e, data);
                }
            }
        }

        debug!("Fetched {} tools from MCP server", tools.len());
        Ok(tools)
    }

    /// Fetches the description of one tool from the MCP server.
    ///
    /// The tool is chosen from `tool_name_override` when it is present and
    /// non-empty, otherwise from the tool name embedded in `mcp_url`. There is
    /// no "first tool" fallback: if neither source names a tool the call
    /// fails before any request is sent.
    ///
    /// Returns the tool's description, or an empty string when the server
    /// lists the tool without one.
    ///
    /// # Errors
    ///
    /// - [`McpClientError::InvalidUrl`] if `mcp_url` cannot be parsed or no
    ///   tool name is available.
    /// - [`McpClientError::ToolNotFound`] if the server's list has no tool
    ///   with that name (including when the list is empty).
    /// - [`McpClientError::HttpError`] if the request fails.
    pub async fn fetch_tool_description(
        &self,
        mcp_url: &str,
        tool_name_override: Option<&str>,
    ) -> Result<String, McpClientError> {
        let (_, url_tool_name) = self.parse_mcp_url(mcp_url)?;

        // An empty override must not mask a usable name in the URL.
        let target_tool_name = tool_name_override
            .filter(|name| !name.is_empty())
            .or(url_tool_name.as_deref())
            .ok_or_else(|| {
                McpClientError::InvalidUrl(
                    "No tool name specified in URL or parameter".to_string(),
                )
            })?;
        debug!("Fetching description for tool: {}", target_tool_name);

        let tools = self.fetch_tools(mcp_url).await?;
        tools
            .into_iter()
            .find(|tool| tool.name == target_tool_name)
            .map(|tool| tool.description.unwrap_or_default())
            .ok_or_else(|| McpClientError::ToolNotFound(target_tool_name.to_string()))
    }

    /// Fetches every tool and returns a map from tool name to description.
    ///
    /// Tools without a description map to an empty string. If the server
    /// lists the same name more than once, the later entry replaces the
    /// earlier one.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`McpClient::fetch_tools`].
    pub async fn fetch_tools_map(
        &self,
        mcp_url: &str,
    ) -> Result<HashMap<String, String>, McpClientError> {
        let tools = self.fetch_tools(mcp_url).await?;
        let descriptions = tools
            .into_iter()
            .map(|tool| (tool.name, tool.description.unwrap_or_default()))
            .collect();
        Ok(descriptions)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `url` with a fresh client and panics if parsing fails.
    fn parse_ok(url: &str) -> (String, Option<String>) {
        McpClient::new().parse_mcp_url(url).unwrap()
    }

    /// A bare `host:port` yields the HTTP base URL and no tool name.
    #[test]
    fn test_parse_mcp_url_basic() {
        let (base, tool) = parse_ok("mcp://localhost:10500");
        assert_eq!(base, "http://localhost:10500");
        assert!(tool.is_none());
    }

    /// A path segment after `host:port` is taken as the tool name.
    #[test]
    fn test_parse_mcp_url_with_path() {
        let (base, tool) = parse_ok("mcp://localhost:10500/rewrite_query");
        assert_eq!(base, "http://localhost:10500");
        assert_eq!(tool.as_deref(), Some("rewrite_query"));
    }

    /// A `tool=` query parameter is taken as the tool name.
    #[test]
    fn test_parse_mcp_url_with_query_param() {
        let (base, tool) = parse_ok("mcp://localhost:10500?tool=rewrite_query");
        assert_eq!(base, "http://localhost:10500");
        assert_eq!(tool.as_deref(), Some("rewrite_query"));
    }

    /// Dotted host names are passed through unchanged.
    #[test]
    fn test_parse_mcp_url_with_host_docker_internal() {
        let (base, tool) = parse_ok("mcp://host.docker.internal:10500/context_builder");
        assert_eq!(base, "http://host.docker.internal:10500");
        assert_eq!(tool.as_deref(), Some("context_builder"));
    }

    /// Any scheme other than `mcp://` is rejected.
    #[test]
    fn test_parse_mcp_url_invalid() {
        let outcome = McpClient::new().parse_mcp_url("http://localhost:10500");
        assert!(outcome.is_err());
    }
}

View file

@ -1 +1,2 @@
pub mod mcp_client;
pub mod tracing;

View file

@ -23,6 +23,14 @@ pub struct Agent {
pub id: String,
pub kind: Option<String>,
pub url: String,
pub tool: Option<String>,
}
/// One entry of the `agent_filters` list in the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentFilter {
    /// Identifier of the filter.
    pub id: String,
    /// Address of the filter endpoint.
    pub url: String,
    /// Optional tool name.
    /// NOTE(review): presumably plays the same role as `Agent::tool`; confirm
    /// against the code that consumes filters.
    pub tool: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -57,6 +65,7 @@ pub struct Configuration {
pub mode: Option<GatewayMode>,
pub routing: Option<Routing>,
pub agents: Option<Vec<Agent>>,
pub agent_filters: Option<Vec<AgentFilter>>,
pub listeners: Vec<Listener>,
}