mirror of
https://github.com/katanemo/plano.git
synced 2026-06-20 15:28:07 +02:00
Merge branch 'main' into salmanap/add-support-for-bedrock-llms
This commit is contained in:
commit
2b54b8833e
29 changed files with 2741 additions and 391 deletions
50
crates/Cargo.lock
generated
50
crates/Cargo.lock
generated
|
|
@ -68,6 +68,16 @@ version = "1.0.98"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487"
|
||||
|
||||
[[package]]
|
||||
name = "assert-json-diff"
|
||||
version = "2.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.88"
|
||||
|
|
@ -202,6 +212,7 @@ dependencies = [
|
|||
"http-body-util",
|
||||
"hyper 1.6.0",
|
||||
"hyper-util",
|
||||
"mockito",
|
||||
"opentelemetry",
|
||||
"opentelemetry-http",
|
||||
"opentelemetry-otlp",
|
||||
|
|
@ -283,6 +294,15 @@ dependencies = [
|
|||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colored"
|
||||
version = "3.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e"
|
||||
dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "common"
|
||||
version = "0.1.0"
|
||||
|
|
@ -1318,6 +1338,30 @@ dependencies = [
|
|||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mockito"
|
||||
version = "1.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7760e0e418d9b7e5777c0374009ca4c93861b9066f18cb334a20ce50ab63aa48"
|
||||
dependencies = [
|
||||
"assert-json-diff",
|
||||
"bytes",
|
||||
"colored",
|
||||
"futures-util",
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"hyper 1.6.0",
|
||||
"hyper-util",
|
||||
"log",
|
||||
"rand 0.9.1",
|
||||
"regex",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"similar",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "native-tls"
|
||||
version = "0.2.14"
|
||||
|
|
@ -2269,6 +2313,12 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "similar"
|
||||
version = "2.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.9"
|
||||
|
|
|
|||
|
|
@ -28,8 +28,14 @@ serde_with = "3.13.0"
|
|||
serde_yaml = "0.9.34"
|
||||
thiserror = "2.0.12"
|
||||
tokio = { version = "1.44.2", features = ["full"] }
|
||||
tokio-stream = "0.1.17"
|
||||
tokio-stream = "0.1"
|
||||
time = { version = "0.3", features = ["formatting", "macros"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
[dev-dependencies]
|
||||
mockito = "1.0"
|
||||
tokio-stream = "0.1.17"
|
||||
tracing = "0.1.41"
|
||||
tracing-opentelemetry = "0.30.0"
|
||||
tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt", "time"] }
|
||||
|
|
|
|||
172
crates/brightstaff/src/handlers/agent_chat_completions.rs
Normal file
172
crates/brightstaff/src/handlers/agent_chat_completions.rs
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{Request, Response};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
use super::response_handler::ResponseHandler;
|
||||
use crate::router::llm_router::RouterService;
|
||||
|
||||
/// Main errors for agent chat completions
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AgentFilterChainError {
|
||||
#[error("Agent selection error: {0}")]
|
||||
Selection(#[from] AgentSelectionError),
|
||||
#[error("Pipeline processing error: {0}")]
|
||||
Pipeline(#[from] PipelineError),
|
||||
#[error("Response handling error: {0}")]
|
||||
Response(#[from] super::response_handler::ResponseError),
|
||||
#[error("Request parsing error: {0}")]
|
||||
RequestParsing(#[from] serde_json::Error),
|
||||
#[error("HTTP error: {0}")]
|
||||
Http(#[from] hyper::Error),
|
||||
}
|
||||
|
||||
pub async fn agent_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
_: String,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
match handle_agent_chat(request, router_service, agents_list, listeners).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
// Print detailed error information with full error chain
|
||||
let mut error_chain = Vec::new();
|
||||
let mut current_error: &dyn std::error::Error = &err;
|
||||
|
||||
// Collect the full error chain
|
||||
loop {
|
||||
error_chain.push(current_error.to_string());
|
||||
match current_error.source() {
|
||||
Some(source) => current_error = source,
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
// Log the complete error chain
|
||||
warn!("Agent chat error chain: {:#?}", error_chain);
|
||||
warn!("Root error: {:?}", err);
|
||||
|
||||
// Create structured error response as JSON
|
||||
let error_json = serde_json::json!({
|
||||
"error": {
|
||||
"type": "AgentFilterChainError",
|
||||
"message": err.to_string(),
|
||||
"error_chain": error_chain,
|
||||
"debug_info": format!("{:?}", err)
|
||||
}
|
||||
});
|
||||
|
||||
// Log the error for debugging
|
||||
info!("Structured error info: {}", error_json);
|
||||
|
||||
// Return JSON error response
|
||||
Ok(ResponseHandler::create_json_error_response(&error_json))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_agent_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||
// Initialize services
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
let pipeline_processor = PipelineProcessor::default();
|
||||
let response_handler = ResponseHandler::new();
|
||||
|
||||
// Extract listener name from headers
|
||||
let listener_name = request
|
||||
.headers()
|
||||
.get("x-arch-agent-listener-name")
|
||||
.and_then(|name| name.to_str().ok());
|
||||
|
||||
// Find the appropriate listener
|
||||
let listener = {
|
||||
let listeners = listeners.read().await;
|
||||
agent_selector
|
||||
.find_listener(listener_name, &listeners)
|
||||
.await?
|
||||
};
|
||||
|
||||
info!("Handling request for listener: {}", listener.name);
|
||||
|
||||
// Parse request body
|
||||
let request_headers = request.headers().clone();
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!(
|
||||
"Received request body (raw utf8): {}",
|
||||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
);
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest =
|
||||
serde_json::from_slice(&chat_request_bytes).map_err(|err| {
|
||||
warn!(
|
||||
"Failed to parse request body as ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
AgentFilterChainError::RequestParsing(err)
|
||||
})?;
|
||||
|
||||
// Extract trace parent for routing
|
||||
let trace_parent = request_headers
|
||||
.iter()
|
||||
.find(|(key, _)| key.as_str() == "traceparent")
|
||||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
// Select appropriate agent using arch router llm model
|
||||
let selected_agent = agent_selector
|
||||
.select_agent(&chat_completions_request.messages, &listener, trace_parent)
|
||||
.await?;
|
||||
|
||||
debug!("Processing agent pipeline: {}", selected_agent.id);
|
||||
|
||||
// Create agent map for pipeline processing
|
||||
let agent_map = {
|
||||
let agents = agents_list.read().await;
|
||||
let agents = agents.as_ref().unwrap();
|
||||
agent_selector.create_agent_map(agents)
|
||||
};
|
||||
|
||||
// Process the filter chain
|
||||
let processed_messages = pipeline_processor
|
||||
.process_filter_chain(
|
||||
&chat_completions_request,
|
||||
&selected_agent,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Get terminal agent and send final response
|
||||
let terminal_agent_name = selected_agent.id;
|
||||
let terminal_agent = agent_map.get(&terminal_agent_name).unwrap();
|
||||
|
||||
debug!("Processing terminal agent: {}", terminal_agent_name);
|
||||
debug!("Terminal agent details: {:?}", terminal_agent);
|
||||
|
||||
let llm_response = pipeline_processor
|
||||
.invoke_upstream_agent(
|
||||
&processed_messages,
|
||||
&chat_completions_request,
|
||||
terminal_agent,
|
||||
&request_headers,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Create streaming response
|
||||
response_handler
|
||||
.create_streaming_response(llm_response)
|
||||
.await
|
||||
.map_err(AgentFilterChainError::from)
|
||||
}
|
||||
296
crates/brightstaff/src/handlers/agent_selector.rs
Normal file
296
crates/brightstaff/src/handlers/agent_selector.rs
Normal file
|
|
@ -0,0 +1,296 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::configuration::{
|
||||
Agent, AgentFilterChain, Listener, ModelUsagePreference, RoutingPreference,
|
||||
};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
|
||||
/// Errors that can occur during agent selection
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AgentSelectionError {
|
||||
#[error("Listener not found for name: {0}")]
|
||||
ListenerNotFound(String),
|
||||
#[error("No agents configured for listener: {0}")]
|
||||
NoAgentsConfigured(String),
|
||||
#[error("Routing service error: {0}")]
|
||||
RoutingError(String),
|
||||
#[error("Default agent not found for listener: {0}")]
|
||||
DefaultAgentNotFound(String),
|
||||
}
|
||||
|
||||
/// Service for selecting agents based on routing preferences and listener configuration
|
||||
pub struct AgentSelector {
|
||||
router_service: Arc<RouterService>,
|
||||
}
|
||||
|
||||
impl AgentSelector {
|
||||
pub fn new(router_service: Arc<RouterService>) -> Self {
|
||||
Self { router_service }
|
||||
}
|
||||
|
||||
/// Find listener by name from the request headers
|
||||
pub async fn find_listener(
|
||||
&self,
|
||||
listener_name: Option<&str>,
|
||||
listeners: &[common::configuration::Listener],
|
||||
) -> Result<Listener, AgentSelectionError> {
|
||||
let listener = listeners
|
||||
.iter()
|
||||
.find(|l| listener_name.map(|name| l.name == name).unwrap_or(false))
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
AgentSelectionError::ListenerNotFound(
|
||||
listener_name.unwrap_or("unknown").to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(listener)
|
||||
}
|
||||
|
||||
/// Create agent name to agent mapping for efficient lookup
|
||||
pub fn create_agent_map(&self, agents: &[Agent]) -> HashMap<String, Agent> {
|
||||
agents
|
||||
.iter()
|
||||
.map(|agent| (agent.id.clone(), agent.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Select appropriate agent based on routing preferences
|
||||
pub async fn select_agent(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
listener: &Listener,
|
||||
trace_parent: Option<String>,
|
||||
) -> Result<AgentFilterChain, AgentSelectionError> {
|
||||
let agents = listener
|
||||
.agents
|
||||
.as_ref()
|
||||
.ok_or_else(|| AgentSelectionError::NoAgentsConfigured(listener.name.clone()))?;
|
||||
|
||||
// If only one agent, skip routing
|
||||
if agents.len() == 1 {
|
||||
debug!("Only one agent available, skipping routing");
|
||||
return Ok(agents[0].clone());
|
||||
}
|
||||
|
||||
let usage_preferences = self.convert_agent_description_to_routing_preferences(agents);
|
||||
debug!(
|
||||
"Agents usage preferences for agent routing str: {}",
|
||||
serde_json::to_string(&usage_preferences).unwrap_or_default()
|
||||
);
|
||||
|
||||
match self
|
||||
.router_service
|
||||
.determine_route(messages, trace_parent, Some(usage_preferences))
|
||||
.await
|
||||
{
|
||||
Ok(Some((_, agent_name))) => {
|
||||
debug!("Determined agent: {}", agent_name);
|
||||
let selected_agent = agents
|
||||
.iter()
|
||||
.find(|a| a.id == agent_name)
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
AgentSelectionError::RoutingError(format!(
|
||||
"Selected agent '{}' not found in listener agents",
|
||||
agent_name
|
||||
))
|
||||
})?;
|
||||
Ok(selected_agent)
|
||||
}
|
||||
Ok(None) => {
|
||||
debug!("No agent determined using routing preferences, using default agent");
|
||||
self.get_default_agent(agents, &listener.name)
|
||||
}
|
||||
Err(err) => Err(AgentSelectionError::RoutingError(err.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the default agent or the first agent if no default is specified
|
||||
fn get_default_agent(
|
||||
&self,
|
||||
agents: &[AgentFilterChain],
|
||||
listener_name: &str,
|
||||
) -> Result<AgentFilterChain, AgentSelectionError> {
|
||||
agents
|
||||
.iter()
|
||||
.find(|a| a.default.unwrap_or(false))
|
||||
.cloned()
|
||||
.or_else(|| {
|
||||
warn!(
|
||||
"No default agent found, routing request to first agent: {}",
|
||||
agents[0].id
|
||||
);
|
||||
Some(agents[0].clone())
|
||||
})
|
||||
.ok_or_else(|| AgentSelectionError::DefaultAgentNotFound(listener_name.to_string()))
|
||||
}
|
||||
|
||||
/// Convert agent descriptions to routing preferences
|
||||
fn convert_agent_description_to_routing_preferences(
|
||||
&self,
|
||||
agents: &[AgentFilterChain],
|
||||
) -> Vec<ModelUsagePreference> {
|
||||
agents
|
||||
.iter()
|
||||
.map(|agent| ModelUsagePreference {
|
||||
model: agent.id.clone(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: agent.id.clone(),
|
||||
description: agent.description.as_ref().unwrap_or(&String::new()).clone(),
|
||||
}],
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::configuration::{AgentFilterChain, Listener};
|
||||
|
||||
fn create_test_router_service() -> Arc<RouterService> {
|
||||
Arc::new(RouterService::new(
|
||||
vec![], // empty providers for testing
|
||||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"test-provider".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn create_test_agent(name: &str, description: &str, is_default: bool) -> AgentFilterChain {
|
||||
AgentFilterChain {
|
||||
id: name.to_string(),
|
||||
description: Some(description.to_string()),
|
||||
default: Some(is_default),
|
||||
filter_chain: vec![name.to_string()],
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
|
||||
Listener {
|
||||
name: name.to_string(),
|
||||
agents: Some(agents),
|
||||
port: 8080,
|
||||
router: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_agent_struct(name: &str) -> Agent {
|
||||
Agent {
|
||||
id: name.to_string(),
|
||||
kind: Some("test".to_string()),
|
||||
url: "http://localhost:8080".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_listener_success() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
let listener1 = create_test_listener("test-listener", vec![]);
|
||||
let listener2 = create_test_listener("other-listener", vec![]);
|
||||
let listeners = vec![listener1.clone(), listener2];
|
||||
|
||||
let result = selector
|
||||
.find_listener(Some("test-listener"), &listeners)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().name, "test-listener");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_listener_not_found() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
let listeners = vec![create_test_listener("other-listener", vec![])];
|
||||
|
||||
let result = selector
|
||||
.find_listener(Some("nonexistent"), &listeners)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
matches!(
|
||||
result.unwrap_err(),
|
||||
AgentSelectionError::ListenerNotFound(_)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_agent_map() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
let agents = vec![
|
||||
create_test_agent_struct("agent1"),
|
||||
create_test_agent_struct("agent2"),
|
||||
];
|
||||
|
||||
let agent_map = selector.create_agent_map(&agents);
|
||||
|
||||
assert_eq!(agent_map.len(), 2);
|
||||
assert!(agent_map.contains_key("agent1"));
|
||||
assert!(agent_map.contains_key("agent2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_agent_description_to_routing_preferences() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
let agents = vec![
|
||||
create_test_agent("agent1", "First agent description", true),
|
||||
create_test_agent("agent2", "Second agent description", false),
|
||||
];
|
||||
|
||||
let preferences = selector.convert_agent_description_to_routing_preferences(&agents);
|
||||
|
||||
assert_eq!(preferences.len(), 2);
|
||||
assert_eq!(preferences[0].model, "agent1");
|
||||
assert_eq!(preferences[0].routing_preferences[0].name, "agent1");
|
||||
assert_eq!(
|
||||
preferences[0].routing_preferences[0].description,
|
||||
"First agent description"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_default_agent() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
let agents = vec![
|
||||
create_test_agent("agent1", "First agent", false),
|
||||
create_test_agent("agent2", "Default agent", true),
|
||||
create_test_agent("agent3", "Third agent", false),
|
||||
];
|
||||
|
||||
let result = selector.get_default_agent(&agents, "test-listener");
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().id, "agent2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_default_agent_fallback_to_first() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
let agents = vec![
|
||||
create_test_agent("agent1", "First agent", false),
|
||||
create_test_agent("agent2", "Second agent", false),
|
||||
];
|
||||
|
||||
let result = selector.get_default_agent(&agents, "test-listener");
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().id, "agent1");
|
||||
}
|
||||
}
|
||||
155
crates/brightstaff/src/handlers/integration_tests.rs
Normal file
155
crates/brightstaff/src/handlers/integration_tests.rs
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use hyper::header::HeaderMap;
|
||||
|
||||
use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
use crate::handlers::pipeline_processor::PipelineProcessor;
|
||||
use crate::handlers::response_handler::ResponseHandler;
|
||||
use crate::router::llm_router::RouterService;
|
||||
|
||||
/// Integration test that demonstrates the modular agent chat flow
|
||||
/// This test shows how the three main components work together:
|
||||
/// 1. AgentSelector - selects the appropriate agent based on routing
|
||||
/// 2. PipelineProcessor - executes the agent pipeline
|
||||
/// 3. ResponseHandler - handles response streaming
|
||||
#[cfg(test)]
|
||||
mod integration_tests {
|
||||
use super::*;
|
||||
use common::configuration::{Agent, AgentFilterChain, Listener};
|
||||
|
||||
fn create_test_router_service() -> Arc<RouterService> {
|
||||
Arc::new(RouterService::new(
|
||||
vec![], // empty providers for testing
|
||||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"test-provider".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn create_test_message(role: Role, content: &str) -> Message {
|
||||
Message {
|
||||
role,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_modular_agent_chat_flow() {
|
||||
// Setup services
|
||||
let router_service = create_test_router_service();
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
let pipeline_processor = PipelineProcessor::default();
|
||||
|
||||
// Create test data
|
||||
let agents = vec![
|
||||
Agent {
|
||||
id: "filter-agent".to_string(),
|
||||
kind: Some("filter".to_string()),
|
||||
url: "http://localhost:8081".to_string(),
|
||||
},
|
||||
Agent {
|
||||
id: "terminal-agent".to_string(),
|
||||
kind: Some("terminal".to_string()),
|
||||
url: "http://localhost:8082".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let agent_pipeline = AgentFilterChain {
|
||||
id: "terminal-agent".to_string(),
|
||||
filter_chain: vec!["filter-agent".to_string(), "terminal-agent".to_string()],
|
||||
description: Some("Test pipeline".to_string()),
|
||||
default: Some(true),
|
||||
};
|
||||
|
||||
let listener = Listener {
|
||||
name: "test-listener".to_string(),
|
||||
agents: Some(vec![agent_pipeline.clone()]),
|
||||
port: 8080,
|
||||
router: None,
|
||||
};
|
||||
|
||||
let listeners = vec![listener];
|
||||
let messages = vec![create_test_message(Role::User, "Hello world!")];
|
||||
|
||||
// Test 1: Agent Selection
|
||||
let selected_listener = agent_selector
|
||||
.find_listener(Some("test-listener"), &listeners)
|
||||
.await;
|
||||
|
||||
assert!(selected_listener.is_ok());
|
||||
let listener = selected_listener.unwrap();
|
||||
assert_eq!(listener.name, "test-listener");
|
||||
|
||||
// Test 2: Agent Map Creation
|
||||
let agent_map = agent_selector.create_agent_map(&agents);
|
||||
assert_eq!(agent_map.len(), 2);
|
||||
assert!(agent_map.contains_key("filter-agent"));
|
||||
assert!(agent_map.contains_key("terminal-agent"));
|
||||
|
||||
// Test 3: Pipeline Processing (empty filter chain for testing)
|
||||
let request = ChatCompletionsRequest {
|
||||
messages: messages.clone(),
|
||||
model: "test-model".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Create a pipeline with empty filter chain to avoid network calls
|
||||
let test_pipeline = AgentFilterChain {
|
||||
id: "terminal-agent".to_string(),
|
||||
filter_chain: vec![], // Empty filter chain - no network calls needed
|
||||
description: None,
|
||||
default: None,
|
||||
};
|
||||
|
||||
let headers = HeaderMap::new();
|
||||
let result = pipeline_processor
|
||||
.process_filter_chain(&request, &test_pipeline, &agent_map, &headers)
|
||||
.await;
|
||||
|
||||
println!("Pipeline processing result: {:?}", result);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let processed_messages = result.unwrap();
|
||||
// With empty filter chain, should return the original messages unchanged
|
||||
assert_eq!(processed_messages.len(), 1);
|
||||
if let MessageContent::Text(content) = &processed_messages[0].content {
|
||||
assert_eq!(content, "Hello world!");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
}
|
||||
|
||||
// Test 4: Error Response Creation
|
||||
let error_response = ResponseHandler::create_bad_request("Test error");
|
||||
assert_eq!(error_response.status(), hyper::StatusCode::BAD_REQUEST);
|
||||
|
||||
println!("✅ All modular components working correctly!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_error_handling_flow() {
|
||||
let router_service = create_test_router_service();
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
|
||||
// Test listener not found
|
||||
let result = agent_selector.find_listener(Some("nonexistent"), &[]).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
AgentSelectionError::ListenerNotFound(_)
|
||||
));
|
||||
|
||||
// Test error response creation
|
||||
let error_response = ResponseHandler::create_internal_error("Pipeline failed");
|
||||
assert_eq!(
|
||||
error_response.status(),
|
||||
hyper::StatusCode::INTERNAL_SERVER_ERROR
|
||||
);
|
||||
|
||||
println!("✅ Error handling working correctly!");
|
||||
}
|
||||
}
|
||||
|
|
@ -1,2 +1,9 @@
|
|||
pub mod agent_chat_completions;
|
||||
pub mod agent_selector;
|
||||
pub mod chat_completions;
|
||||
pub mod models;
|
||||
pub mod pipeline_processor;
|
||||
pub mod response_handler;
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests;
|
||||
|
|
|
|||
228
crates/brightstaff/src/handlers/pipeline_processor.rs
Normal file
228
crates/brightstaff/src/handlers/pipeline_processor.rs
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use common::configuration::{Agent, AgentFilterChain};
|
||||
use common::consts::{ARCH_UPSTREAM_HOST_HEADER, ENVOY_RETRY_HEADER};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
|
||||
use hyper::header::HeaderMap;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
/// Errors that can occur during pipeline processing
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PipelineError {
|
||||
#[error("HTTP request failed: {0}")]
|
||||
RequestFailed(#[from] reqwest::Error),
|
||||
#[error("Failed to parse response: {0}")]
|
||||
ParseError(#[from] serde_json::Error),
|
||||
#[error("Agent '{0}' not found in agent map")]
|
||||
AgentNotFound(String),
|
||||
#[error("No choices in response from agent '{0}'")]
|
||||
NoChoicesInResponse(String),
|
||||
#[error("No content in response from agent '{0}'")]
|
||||
NoContentInResponse(String),
|
||||
}
|
||||
|
||||
/// Service for processing agent pipelines
|
||||
pub struct PipelineProcessor {
|
||||
client: reqwest::Client,
|
||||
url: String,
|
||||
}
|
||||
|
||||
impl Default for PipelineProcessor {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
url: "http://localhost:11000/v1/chat/completions".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PipelineProcessor {
|
||||
pub fn new(url: String) -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
url,
|
||||
}
|
||||
}
|
||||
|
||||
/// Process the filter chain of agents (all except the terminal agent)
|
||||
pub async fn process_filter_chain(
|
||||
&self,
|
||||
initial_request: &ChatCompletionsRequest,
|
||||
agent_filter_chain: &AgentFilterChain,
|
||||
agent_map: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
let mut chat_completions_history = initial_request.messages.clone();
|
||||
|
||||
for agent_name in &agent_filter_chain.filter_chain {
|
||||
debug!("Processing filter agent: {}", agent_name);
|
||||
|
||||
let agent = agent_map
|
||||
.get(agent_name)
|
||||
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
|
||||
|
||||
debug!("Agent details: {:?}", agent);
|
||||
|
||||
let response_content = self
|
||||
.send_agent_filter_chain_request(
|
||||
&chat_completions_history,
|
||||
initial_request,
|
||||
agent,
|
||||
request_headers,
|
||||
)
|
||||
.await?;
|
||||
|
||||
debug!("Received response from filter agent {}", agent_name);
|
||||
|
||||
// Parse the response content as new message history
|
||||
chat_completions_history =
|
||||
serde_json::from_str(&response_content).inspect_err(|err| {
|
||||
warn!(
|
||||
"Failed to parse response from agent {}, err: {}, response: {}",
|
||||
agent_name, err, response_content
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(chat_completions_history)
|
||||
}
|
||||
|
||||
/// Send request to a specific agent and return the response content
|
||||
async fn send_agent_filter_chain_request(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
original_request: &ChatCompletionsRequest,
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<String, PipelineError> {
|
||||
let mut request = original_request.clone();
|
||||
request.messages = messages.to_vec();
|
||||
|
||||
let request_body = serde_json::to_string(&request)?;
|
||||
debug!("Sending request to agent {}", agent.id);
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.url)
|
||||
.headers(agent_headers)
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
||||
// Parse the response as JSON to extract the content
|
||||
let response_json: serde_json::Value = serde_json::from_slice(&response_bytes)?;
|
||||
|
||||
let content = response_json
|
||||
.get("choices")
|
||||
.and_then(|choices| choices.as_array())
|
||||
.and_then(|choices| choices.first())
|
||||
.and_then(|choice| choice.get("message"))
|
||||
.and_then(|message| message.get("content"))
|
||||
.and_then(|content| content.as_str())
|
||||
.ok_or_else(|| PipelineError::NoContentInResponse(agent.id.clone()))?
|
||||
.to_string();
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
/// Send request to terminal agent and return the raw response for streaming
|
||||
pub async fn invoke_upstream_agent(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
original_request: &ChatCompletionsRequest,
|
||||
terminal_agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<reqwest::Response, PipelineError> {
|
||||
let mut request = original_request.clone();
|
||||
request.messages = messages.to_vec();
|
||||
|
||||
let request_body = serde_json::to_string(&request)?;
|
||||
debug!("Sending request to terminal agent {}", terminal_agent.id);
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&terminal_agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(terminal_agent.id.clone()))?,
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.url)
|
||||
.headers(agent_headers)
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn create_test_message(role: Role, content: &str) -> Message {
|
||||
Message {
|
||||
role,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
|
||||
AgentFilterChain {
|
||||
id: "test-agent".to_string(),
|
||||
filter_chain: agents.iter().map(|s| s.to_string()).collect(),
|
||||
description: None,
|
||||
default: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_agent_not_found_error() {
|
||||
let processor = PipelineProcessor::default();
|
||||
let agent_map = HashMap::new();
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let initial_request = ChatCompletionsRequest {
|
||||
messages: vec![create_test_message(Role::User, "Hello")],
|
||||
model: "test-model".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
|
||||
|
||||
let result = processor
|
||||
.process_filter_chain(&initial_request, &pipeline, &agent_map, &request_headers)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
matches!(result.unwrap_err(), PipelineError::AgentNotFound(_));
|
||||
}
|
||||
}
|
||||
191
crates/brightstaff/src/handlers/response_handler.rs
Normal file
191
crates/brightstaff/src/handlers/response_handler.rs
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
use bytes::Bytes;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
use hyper::{Response, StatusCode};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::warn;
|
||||
|
||||
/// Errors that can occur during response handling
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ResponseError {
|
||||
#[error("Failed to create response: {0}")]
|
||||
ResponseCreationFailed(#[from] hyper::http::Error),
|
||||
#[error("Stream error: {0}")]
|
||||
StreamError(String),
|
||||
}
|
||||
|
||||
/// Service for handling HTTP responses and streaming
|
||||
pub struct ResponseHandler;
|
||||
|
||||
impl ResponseHandler {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
/// Create a full response body from bytes
|
||||
pub fn create_full_body<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// Create an error response with a given status code and message
|
||||
pub fn create_error_response(
|
||||
status: StatusCode,
|
||||
message: &str,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let mut response = Response::new(Self::create_full_body(message.to_string()));
|
||||
*response.status_mut() = status;
|
||||
response
|
||||
}
|
||||
|
||||
/// Create a bad request response
|
||||
pub fn create_bad_request(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
Self::create_error_response(StatusCode::BAD_REQUEST, message)
|
||||
}
|
||||
|
||||
/// Create an internal server error response
|
||||
pub fn create_internal_error(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
Self::create_error_response(StatusCode::INTERNAL_SERVER_ERROR, message)
|
||||
}
|
||||
|
||||
/// Create a JSON error response
|
||||
pub fn create_json_error_response(
|
||||
error_json: &serde_json::Value,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let json_string = error_json.to_string();
|
||||
let mut response = Response::new(Self::create_full_body(json_string));
|
||||
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
response.headers_mut().insert(
|
||||
hyper::header::CONTENT_TYPE,
|
||||
"application/json".parse().unwrap(),
|
||||
);
|
||||
response
|
||||
}
|
||||
|
||||
/// Create a streaming response from a reqwest response
|
||||
pub async fn create_streaming_response(
|
||||
&self,
|
||||
llm_response: reqwest::Response,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ResponseError> {
|
||||
// Copy headers from the original response
|
||||
let response_headers = llm_response.headers();
|
||||
let mut response_builder = Response::builder();
|
||||
|
||||
let headers = response_builder.headers_mut().ok_or_else(|| {
|
||||
ResponseError::StreamError("Failed to get mutable headers".to_string())
|
||||
})?;
|
||||
|
||||
for (header_name, header_value) in response_headers.iter() {
|
||||
headers.insert(header_name, header_value.clone());
|
||||
}
|
||||
|
||||
// Create channel for async streaming
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(16);
|
||||
|
||||
// Spawn task to stream data
|
||||
tokio::spawn(async move {
|
||||
let mut byte_stream = llm_response.bytes_stream();
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let chunk = match item {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => {
|
||||
warn!("Error receiving chunk: {:?}", err);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if tx.send(chunk).await.is_err() {
|
||||
warn!("Receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
||||
response_builder
|
||||
.body(stream_body)
|
||||
.map_err(ResponseError::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ResponseHandler {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use hyper::StatusCode;
|
||||
|
||||
#[test]
|
||||
fn test_create_bad_request() {
|
||||
let response = ResponseHandler::create_bad_request("Invalid request");
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_internal_error() {
|
||||
let response = ResponseHandler::create_internal_error("Server error");
|
||||
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_error_response() {
|
||||
let response =
|
||||
ResponseHandler::create_error_response(StatusCode::NOT_FOUND, "Resource not found");
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_json_error_response() {
|
||||
let error_json = serde_json::json!({
|
||||
"error": {
|
||||
"type": "TestError",
|
||||
"message": "Test error message"
|
||||
}
|
||||
});
|
||||
|
||||
let response = ResponseHandler::create_json_error_response(&error_json);
|
||||
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
assert_eq!(
|
||||
response.headers().get("content-type").unwrap(),
|
||||
"application/json"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_streaming_response_with_mock() {
|
||||
use mockito::Server;
|
||||
|
||||
let mut server = Server::new_async().await;
|
||||
let mock = server
|
||||
.mock("GET", "/test")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "text/plain")
|
||||
.with_body("streaming response")
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let llm_response = client.get(&(server.url() + "/test")).send().await.unwrap();
|
||||
|
||||
let handler = ResponseHandler::new();
|
||||
let result = handler.create_streaming_response(llm_response).await;
|
||||
|
||||
mock.assert_async().await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let response = result.unwrap();
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert!(response.headers().contains_key("content-type"));
|
||||
}
|
||||
}
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
use brightstaff::handlers::agent_chat_completions::agent_chat;
|
||||
use brightstaff::handlers::chat_completions::chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
|
|
@ -61,7 +62,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let arch_config = Arc::new(config);
|
||||
|
||||
let llm_providers = Arc::new(RwLock::new(arch_config.llm_providers.clone()));
|
||||
let llm_providers = Arc::new(RwLock::new(arch_config.model_providers.clone()));
|
||||
let agents_list = Arc::new(RwLock::new(arch_config.agents.clone()));
|
||||
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
|
||||
|
||||
debug!(
|
||||
"arch_config: {:?}",
|
||||
|
|
@ -84,11 +87,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let routing_llm_provider = arch_config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.llm_provider.clone())
|
||||
.and_then(|r| r.model_provider.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||
arch_config.llm_providers.clone(),
|
||||
arch_config.model_providers.clone(),
|
||||
llm_provider_url.clone() + CHAT_COMPLETIONS_PATH,
|
||||
routing_model_name,
|
||||
routing_llm_provider,
|
||||
|
|
@ -106,12 +109,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let llm_provider_url = llm_provider_url.clone();
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let agents_list = agents_list.clone();
|
||||
let listeners = listeners.clone();
|
||||
let service = service_fn(move |req| {
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
let llm_providers = llm_providers.clone();
|
||||
let model_aliases = Arc::clone(&model_aliases);
|
||||
let agents_list = agents_list.clone();
|
||||
let listeners = listeners.clone();
|
||||
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
|
|
@ -122,8 +129,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await),
|
||||
(&Method::OPTIONS, "/v1/models") => {
|
||||
(&Method::POST, "/agents/v1/chat/completions") => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, req.uri().path());
|
||||
agent_chat(
|
||||
req,
|
||||
router_service,
|
||||
fully_qualified_url,
|
||||
agents_list,
|
||||
listeners,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
|
||||
Ok(list_models(llm_providers).await)
|
||||
}
|
||||
// hack for now to get openw-web-ui to work
|
||||
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => {
|
||||
let mut response = Response::new(empty());
|
||||
*response.status_mut() = StatusCode::NO_CONTENT;
|
||||
response
|
||||
|
|
@ -147,6 +170,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
Ok(response)
|
||||
}
|
||||
_ => {
|
||||
debug!("No route for {} {}", req.method(), req.uri().path());
|
||||
let mut not_found = Response::new(empty());
|
||||
*not_found.status_mut() = StatusCode::NOT_FOUND;
|
||||
Ok(not_found)
|
||||
|
|
|
|||
|
|
@ -79,7 +79,13 @@ impl RouterService {
|
|||
trace_parent: Option<String>,
|
||||
usage_preferences: Option<Vec<ModelUsagePreference>>,
|
||||
) -> Result<Option<(String, String)>> {
|
||||
if !self.llm_usage_defined {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if (usage_preferences.is_none() || usage_preferences.as_ref().unwrap().len() < 2)
|
||||
&& !self.llm_usage_defined
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ use crate::api::open_ai::{
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Routing {
|
||||
pub llm_provider: Option<String>,
|
||||
pub model_provider: Option<String>,
|
||||
pub model: Option<String>,
|
||||
}
|
||||
|
||||
|
|
@ -18,11 +18,34 @@ pub struct ModelAlias {
|
|||
pub target: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Agent {
|
||||
pub id: String,
|
||||
pub kind: Option<String>,
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AgentFilterChain {
|
||||
pub id: String,
|
||||
pub default: Option<bool>,
|
||||
pub description: Option<String>,
|
||||
pub filter_chain: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Listener {
|
||||
pub name: String,
|
||||
pub router: Option<String>,
|
||||
pub agents: Option<Vec<AgentFilterChain>>,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub version: String,
|
||||
pub endpoints: Option<HashMap<String, Endpoint>>,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub model_providers: Vec<LlmProvider>,
|
||||
pub model_aliases: Option<HashMap<String, ModelAlias>>,
|
||||
pub overrides: Option<Overrides>,
|
||||
pub system_prompt: Option<String>,
|
||||
|
|
@ -33,6 +56,8 @@ pub struct Configuration {
|
|||
pub tracing: Option<Tracing>,
|
||||
pub mode: Option<GatewayMode>,
|
||||
pub routing: Option<Routing>,
|
||||
pub agents: Option<Vec<Agent>>,
|
||||
pub listeners: Vec<Listener>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
|
|
|
|||
|
|
@ -30,3 +30,4 @@ pub const HALLUCINATION_TEMPLATE: &str =
|
|||
pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http";
|
||||
pub const OTEL_POST_PATH: &str = "/v1/traces";
|
||||
pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route";
|
||||
pub const ENVOY_RETRY_HEADER: &str = "x-envoy-max-retries";
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ impl RootContext for FilterContext {
|
|||
ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default()));
|
||||
self.overrides = Rc::new(config.overrides);
|
||||
|
||||
match config.llm_providers.try_into() {
|
||||
match config.model_providers.try_into() {
|
||||
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
|
||||
Err(err) => panic!("{err}"),
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue