mirror of
https://github.com/katanemo/plano.git
synced 2026-06-26 15:39:40 +02:00
more changes
This commit is contained in:
parent
dcfc85ca74
commit
4140c1cde4
15 changed files with 883 additions and 109 deletions
|
|
@ -124,20 +124,25 @@ async fn handle_agent_chat(
|
|||
.find(|(key, _)| key.as_str() == "traceparent")
|
||||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
// Select appropriate agent using arch router llm model
|
||||
let selected_agent = agent_selector
|
||||
.select_agent(&chat_completions_request.messages, &listener, trace_parent)
|
||||
.await?;
|
||||
|
||||
debug!("Processing agent pipeline: {}", selected_agent.id);
|
||||
|
||||
// Create agent map for pipeline processing
|
||||
// Create agent map for pipeline processing and agent selection
|
||||
let agent_map = {
|
||||
let agents = agents_list.read().await;
|
||||
let agents = agents.as_ref().unwrap();
|
||||
agent_selector.create_agent_map(agents)
|
||||
};
|
||||
|
||||
// Select appropriate agent using arch router llm model
|
||||
let selected_agent = agent_selector
|
||||
.select_agent(
|
||||
&chat_completions_request.messages,
|
||||
&listener,
|
||||
trace_parent,
|
||||
&agent_map,
|
||||
)
|
||||
.await?;
|
||||
|
||||
debug!("Processing agent pipeline: {}", selected_agent.id);
|
||||
|
||||
// Process the filter chain
|
||||
let processed_messages = pipeline_processor
|
||||
.process_filter_chain(
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ use hermesllm::apis::openai::Message;
|
|||
use tracing::{debug, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::utils::mcp_client::McpClient;
|
||||
|
||||
/// Errors that can occur during agent selection
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -20,16 +21,22 @@ pub enum AgentSelectionError {
|
|||
RoutingError(String),
|
||||
#[error("Default agent not found for listener: {0}")]
|
||||
DefaultAgentNotFound(String),
|
||||
#[error("MCP client error: {0}")]
|
||||
McpError(String),
|
||||
}
|
||||
|
||||
/// Service for selecting agents based on routing preferences and listener configuration
|
||||
pub struct AgentSelector {
|
||||
router_service: Arc<RouterService>,
|
||||
mcp_client: McpClient,
|
||||
}
|
||||
|
||||
impl AgentSelector {
|
||||
pub fn new(router_service: Arc<RouterService>) -> Self {
|
||||
Self { router_service }
|
||||
Self {
|
||||
router_service,
|
||||
mcp_client: McpClient::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Find listener by name from the request headers
|
||||
|
|
@ -65,6 +72,7 @@ impl AgentSelector {
|
|||
messages: &[Message],
|
||||
listener: &Listener,
|
||||
trace_parent: Option<String>,
|
||||
agent_map: &HashMap<String, Agent>,
|
||||
) -> Result<AgentFilterChain, AgentSelectionError> {
|
||||
let agents = listener
|
||||
.agents
|
||||
|
|
@ -77,7 +85,9 @@ impl AgentSelector {
|
|||
return Ok(agents[0].clone());
|
||||
}
|
||||
|
||||
let usage_preferences = self.convert_agent_description_to_routing_preferences(agents);
|
||||
let usage_preferences = self
|
||||
.convert_agent_description_to_routing_preferences(agents, agent_map)
|
||||
.await;
|
||||
debug!(
|
||||
"Agents usage preferences for agent routing str: {}",
|
||||
serde_json::to_string(&usage_preferences).unwrap_or_default()
|
||||
|
|
@ -131,20 +141,75 @@ impl AgentSelector {
|
|||
}
|
||||
|
||||
/// Convert agent descriptions to routing preferences
|
||||
fn convert_agent_description_to_routing_preferences(
|
||||
/// For agents with MCP URLs, fetches the tool description from the MCP server
|
||||
async fn convert_agent_description_to_routing_preferences(
|
||||
&self,
|
||||
agents: &[AgentFilterChain],
|
||||
agent_map: &HashMap<String, Agent>,
|
||||
) -> Vec<ModelUsagePreference> {
|
||||
agents
|
||||
.iter()
|
||||
.map(|agent| ModelUsagePreference {
|
||||
model: agent.id.clone(),
|
||||
let mut preferences = Vec::new();
|
||||
|
||||
for agent_chain in agents {
|
||||
// Get the actual agent from the agent_map
|
||||
let agent = agent_map.get(&agent_chain.id);
|
||||
|
||||
// Determine the description to use
|
||||
let description = if let Some(agent) = agent {
|
||||
// Check if this is an MCP agent (URL starts with mcp://)
|
||||
if agent.url.starts_with("mcp://") {
|
||||
debug!(
|
||||
"Agent {} is an MCP agent, fetching tool description from: {}",
|
||||
agent.id, agent.url
|
||||
);
|
||||
|
||||
// Fetch description from MCP endpoint
|
||||
match self
|
||||
.mcp_client
|
||||
.fetch_tool_description(&agent.url, agent.tool.as_deref())
|
||||
.await
|
||||
{
|
||||
Ok(mcp_description) => {
|
||||
if !mcp_description.is_empty() {
|
||||
debug!(
|
||||
"Fetched MCP description for agent {}: {}",
|
||||
agent.id, mcp_description
|
||||
);
|
||||
mcp_description
|
||||
} else {
|
||||
warn!(
|
||||
"MCP tool description is empty for agent {}, using config description",
|
||||
agent.id
|
||||
);
|
||||
agent_chain.description.clone().unwrap_or_default()
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to fetch MCP description for agent {}: {}, using config description",
|
||||
agent.id, e
|
||||
);
|
||||
agent_chain.description.clone().unwrap_or_default()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Not an MCP agent, use description from config
|
||||
agent_chain.description.clone().unwrap_or_default()
|
||||
}
|
||||
} else {
|
||||
// Agent not found in map, use description from config
|
||||
agent_chain.description.clone().unwrap_or_default()
|
||||
};
|
||||
|
||||
preferences.push(ModelUsagePreference {
|
||||
model: agent_chain.id.clone(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: agent.id.clone(),
|
||||
description: agent.description.as_ref().unwrap_or(&String::new()).clone(),
|
||||
name: agent_chain.id.clone(),
|
||||
description,
|
||||
}],
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
}
|
||||
|
||||
preferences
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -185,6 +250,7 @@ mod tests {
|
|||
id: name.to_string(),
|
||||
kind: Some("test".to_string()),
|
||||
url: "http://localhost:8080".to_string(),
|
||||
tool: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -240,8 +306,8 @@ mod tests {
|
|||
assert!(agent_map.contains_key("agent2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_agent_description_to_routing_preferences() {
|
||||
#[tokio::test]
|
||||
async fn test_convert_agent_description_to_routing_preferences() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
|
|
@ -250,7 +316,15 @@ mod tests {
|
|||
create_test_agent("agent2", "Second agent description", false),
|
||||
];
|
||||
|
||||
let preferences = selector.convert_agent_description_to_routing_preferences(&agents);
|
||||
let agent_structs = vec![
|
||||
create_test_agent_struct("agent1"),
|
||||
create_test_agent_struct("agent2"),
|
||||
];
|
||||
let agent_map = selector.create_agent_map(&agent_structs);
|
||||
|
||||
let preferences = selector
|
||||
.convert_agent_description_to_routing_preferences(&agents, &agent_map)
|
||||
.await;
|
||||
|
||||
assert_eq!(preferences.len(), 2);
|
||||
assert_eq!(preferences[0].model, "agent1");
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let llm_providers = Arc::new(RwLock::new(arch_config.model_providers.clone()));
|
||||
let agents_list = Arc::new(RwLock::new(arch_config.agents.clone()));
|
||||
let agent_filters = Arc::new(RwLock::new(arch_config.agent_filters.clone()));
|
||||
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
|
||||
|
||||
debug!(
|
||||
|
|
@ -111,6 +112,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let agents_list = agents_list.clone();
|
||||
let agent_filters = agent_filters.clone();
|
||||
let listeners = listeners.clone();
|
||||
let service = service_fn(move |req| {
|
||||
let router_service = Arc::clone(&router_service);
|
||||
|
|
@ -119,6 +121,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let llm_providers = llm_providers.clone();
|
||||
let model_aliases = Arc::clone(&model_aliases);
|
||||
let agents_list = agents_list.clone();
|
||||
let agent_filters = agent_filters.clone();
|
||||
let listeners = listeners.clone();
|
||||
|
||||
async move {
|
||||
|
|
|
|||
235
crates/brightstaff/src/utils/mcp_client.rs
Normal file
235
crates/brightstaff/src/utils/mcp_client.rs
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
/// MCP Tool definition from tools/list response
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpTool {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
#[serde(rename = "inputSchema")]
|
||||
pub input_schema: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Response from MCP tools/list endpoint
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct McpToolsListResponse {
|
||||
tools: Vec<McpTool>,
|
||||
}
|
||||
|
||||
/// Errors that can occur during MCP communication
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum McpClientError {
|
||||
#[error("HTTP request failed: {0}")]
|
||||
HttpError(#[from] reqwest::Error),
|
||||
#[error("Failed to parse response: {0}")]
|
||||
ParseError(#[from] serde_json::Error),
|
||||
#[error("Invalid MCP URL: {0}")]
|
||||
InvalidUrl(String),
|
||||
#[error("Tool not found: {0}")]
|
||||
ToolNotFound(String),
|
||||
}
|
||||
|
||||
/// Client for communicating with MCP (Model Context Protocol) servers
|
||||
pub struct McpClient {
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl Default for McpClient {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl McpClient {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse MCP URL to extract host, port, and optional tool name
|
||||
/// Supports formats:
|
||||
/// - mcp://host:port
|
||||
/// - mcp://host:port/tool_name
|
||||
/// - mcp://host:port?tool=tool_name
|
||||
fn parse_mcp_url(&self, mcp_url: &str) -> Result<(String, Option<String>), McpClientError> {
|
||||
// Remove mcp:// prefix
|
||||
let url_without_scheme = mcp_url
|
||||
.strip_prefix("mcp://")
|
||||
.ok_or_else(|| McpClientError::InvalidUrl(format!("URL must start with mcp://: {}", mcp_url)))?;
|
||||
|
||||
// Parse host:port and optional tool
|
||||
let base_url: String;
|
||||
let mut tool_name: Option<String> = None;
|
||||
|
||||
if let Some(query_start) = url_without_scheme.find('?') {
|
||||
// Format: mcp://host:port?tool=tool_name
|
||||
base_url = url_without_scheme[..query_start].to_string();
|
||||
let query = &url_without_scheme[query_start + 1..];
|
||||
|
||||
// Parse query parameters
|
||||
for param in query.split('&') {
|
||||
if let Some((key, value)) = param.split_once('=') {
|
||||
if key == "tool" {
|
||||
tool_name = Some(value.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if let Some(path_start) = url_without_scheme.find('/') {
|
||||
// Format: mcp://host:port/tool_name
|
||||
base_url = url_without_scheme[..path_start].to_string();
|
||||
tool_name = Some(url_without_scheme[path_start + 1..].to_string());
|
||||
} else {
|
||||
// Format: mcp://host:port
|
||||
base_url = url_without_scheme.to_string();
|
||||
}
|
||||
|
||||
Ok((format!("http://{}", base_url), tool_name))
|
||||
}
|
||||
|
||||
/// Fetch list of tools from MCP server via SSE
|
||||
pub async fn fetch_tools(&self, mcp_url: &str) -> Result<Vec<McpTool>, McpClientError> {
|
||||
let (http_url, _) = self.parse_mcp_url(mcp_url)?;
|
||||
let tools_list_url = format!("{}/sse/tools/list", http_url);
|
||||
|
||||
debug!("Fetching tools from MCP endpoint: {}", tools_list_url);
|
||||
|
||||
let response = self.client
|
||||
.get(&tools_list_url)
|
||||
.header("Accept", "text/event-stream")
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
warn!(
|
||||
"Failed to fetch tools from {}: status {}",
|
||||
tools_list_url,
|
||||
response.status()
|
||||
);
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let body = response.text().await?;
|
||||
debug!("Received tools list response: {}", body);
|
||||
|
||||
// Parse SSE response - looking for data: lines
|
||||
let mut tools = Vec::new();
|
||||
for line in body.lines() {
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if data.trim() == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
|
||||
match serde_json::from_str::<McpToolsListResponse>(data) {
|
||||
Ok(response) => {
|
||||
tools.extend(response.tools);
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to parse tools list data: {}, line: {}", e, data);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Fetched {} tools from MCP server", tools.len());
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
/// Fetch specific tool description from MCP server
|
||||
/// If tool_name is None, uses the tool name from the URL or returns the first tool
|
||||
pub async fn fetch_tool_description(
|
||||
&self,
|
||||
mcp_url: &str,
|
||||
tool_name_override: Option<&str>,
|
||||
) -> Result<String, McpClientError> {
|
||||
let (_, url_tool_name) = self.parse_mcp_url(mcp_url)?;
|
||||
|
||||
// Determine which tool to look for
|
||||
let target_tool_name = tool_name_override
|
||||
.or(url_tool_name.as_deref())
|
||||
.ok_or_else(|| {
|
||||
McpClientError::InvalidUrl(
|
||||
"No tool name specified in URL or parameter".to_string()
|
||||
)
|
||||
})?;
|
||||
|
||||
debug!("Fetching description for tool: {}", target_tool_name);
|
||||
|
||||
let tools = self.fetch_tools(mcp_url).await?;
|
||||
|
||||
let tool = tools
|
||||
.iter()
|
||||
.find(|t| t.name == target_tool_name)
|
||||
.ok_or_else(|| McpClientError::ToolNotFound(target_tool_name.to_string()))?;
|
||||
|
||||
Ok(tool.description.clone().unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Fetch all tools as a map of tool name to description
|
||||
pub async fn fetch_tools_map(
|
||||
&self,
|
||||
mcp_url: &str,
|
||||
) -> Result<HashMap<String, String>, McpClientError> {
|
||||
let tools = self.fetch_tools(mcp_url).await?;
|
||||
|
||||
Ok(tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
(tool.name, tool.description.unwrap_or_default())
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_mcp_url_basic() {
|
||||
let client = McpClient::new();
|
||||
|
||||
let (http_url, tool) = client.parse_mcp_url("mcp://localhost:10500").unwrap();
|
||||
assert_eq!(http_url, "http://localhost:10500");
|
||||
assert_eq!(tool, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mcp_url_with_path() {
|
||||
let client = McpClient::new();
|
||||
|
||||
let (http_url, tool) = client.parse_mcp_url("mcp://localhost:10500/rewrite_query").unwrap();
|
||||
assert_eq!(http_url, "http://localhost:10500");
|
||||
assert_eq!(tool, Some("rewrite_query".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mcp_url_with_query_param() {
|
||||
let client = McpClient::new();
|
||||
|
||||
let (http_url, tool) = client.parse_mcp_url("mcp://localhost:10500?tool=rewrite_query").unwrap();
|
||||
assert_eq!(http_url, "http://localhost:10500");
|
||||
assert_eq!(tool, Some("rewrite_query".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mcp_url_with_host_docker_internal() {
|
||||
let client = McpClient::new();
|
||||
|
||||
let (http_url, tool) = client
|
||||
.parse_mcp_url("mcp://host.docker.internal:10500/context_builder")
|
||||
.unwrap();
|
||||
assert_eq!(http_url, "http://host.docker.internal:10500");
|
||||
assert_eq!(tool, Some("context_builder".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mcp_url_invalid() {
|
||||
let client = McpClient::new();
|
||||
|
||||
let result = client.parse_mcp_url("http://localhost:10500");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
|
@ -1 +1,2 @@
|
|||
pub mod mcp_client;
|
||||
pub mod tracing;
|
||||
|
|
|
|||
|
|
@ -23,6 +23,14 @@ pub struct Agent {
|
|||
pub id: String,
|
||||
pub kind: Option<String>,
|
||||
pub url: String,
|
||||
pub tool: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AgentFilter {
|
||||
pub id: String,
|
||||
pub url: String,
|
||||
pub tool: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -57,6 +65,7 @@ pub struct Configuration {
|
|||
pub mode: Option<GatewayMode>,
|
||||
pub routing: Option<Routing>,
|
||||
pub agents: Option<Vec<Agent>>,
|
||||
pub agent_filters: Option<Vec<AgentFilter>>,
|
||||
pub listeners: Vec<Listener>,
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue