mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add tests and some refactor
This commit is contained in:
parent
08471d8adf
commit
2229f0d4d4
8 changed files with 1004 additions and 276 deletions
50
crates/Cargo.lock
generated
50
crates/Cargo.lock
generated
|
|
@ -104,6 +104,16 @@ version = "1.4.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223"
|
||||
|
||||
[[package]]
|
||||
name = "assert-json-diff"
|
||||
version = "2.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.88"
|
||||
|
|
@ -216,6 +226,7 @@ dependencies = [
|
|||
"http-body-util",
|
||||
"hyper 1.6.0",
|
||||
"hyper-util",
|
||||
"mockito",
|
||||
"opentelemetry",
|
||||
"opentelemetry-http",
|
||||
"opentelemetry-otlp",
|
||||
|
|
@ -321,6 +332,15 @@ version = "0.2.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15"
|
||||
|
||||
[[package]]
|
||||
name = "colored"
|
||||
version = "3.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e"
|
||||
dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "common"
|
||||
version = "0.1.0"
|
||||
|
|
@ -1731,6 +1751,30 @@ dependencies = [
|
|||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mockito"
|
||||
version = "1.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7760e0e418d9b7e5777c0374009ca4c93861b9066f18cb334a20ce50ab63aa48"
|
||||
dependencies = [
|
||||
"assert-json-diff",
|
||||
"bytes",
|
||||
"colored",
|
||||
"futures-util",
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"hyper 1.6.0",
|
||||
"hyper-util",
|
||||
"log",
|
||||
"rand 0.9.1",
|
||||
"regex",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"similar",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "more-asserts"
|
||||
version = "0.3.1"
|
||||
|
|
@ -2817,6 +2861,12 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "similar"
|
||||
version = "2.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.9"
|
||||
|
|
|
|||
|
|
@ -28,6 +28,12 @@ serde_with = "3.13.0"
|
|||
serde_yaml = "0.9.34"
|
||||
thiserror = "2.0.12"
|
||||
tokio = { version = "1.44.2", features = ["full"] }
|
||||
tokio-stream = "0.1"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
[dev-dependencies]
|
||||
mockito = "1.0"
|
||||
tokio-stream = "0.1.17"
|
||||
tracing = "0.1.41"
|
||||
tracing-opentelemetry = "0.30.0"
|
||||
|
|
|
|||
|
|
@ -1,25 +1,30 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{AgentPipeline, ModelUsagePreference, RoutingPreference};
|
||||
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{Request, Response};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use super::agent_selector::{AgentSelector, AgentSelectionError};
|
||||
use super::pipeline_processor::{PipelineProcessor, PipelineError};
|
||||
use super::response_handler::ResponseHandler;
|
||||
|
||||
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
/// Main errors for agent chat completions
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AgentChatError {
|
||||
#[error("Agent selection error: {0}")]
|
||||
Selection(#[from] AgentSelectionError),
|
||||
#[error("Pipeline processing error: {0}")]
|
||||
Pipeline(#[from] PipelineError),
|
||||
#[error("Response handling error: {0}")]
|
||||
Response(#[from] super::response_handler::ResponseError),
|
||||
#[error("Request parsing error: {0}")]
|
||||
RequestParsing(#[from] serde_json::Error),
|
||||
#[error("HTTP error: {0}")]
|
||||
Http(#[from] hyper::Error),
|
||||
}
|
||||
|
||||
pub async fn agent_chat(
|
||||
|
|
@ -29,25 +34,44 @@ pub async fn agent_chat(
|
|||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
// find listener that is running at port 8001 for agents
|
||||
let listener_name = request.headers().get("x-arch-agent-listener-name");
|
||||
match handle_agent_chat(request, router_service, agents_list, listeners).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
warn!("Agent chat error: {}", err);
|
||||
Ok(ResponseHandler::create_internal_error(&format!("Internal error: {}", err)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_agent_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentChatError> {
|
||||
// Initialize services
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
let pipeline_processor = PipelineProcessor::default();
|
||||
let response_handler = ResponseHandler::new();
|
||||
|
||||
// Extract listener name from headers
|
||||
let listener_name = request
|
||||
.headers()
|
||||
.get("x-arch-agent-listener-name")
|
||||
.and_then(|name| name.to_str().ok());
|
||||
|
||||
// Find the appropriate listener
|
||||
let listener = {
|
||||
let listeners = listeners.read().await;
|
||||
listeners
|
||||
.iter()
|
||||
.find(|l| {
|
||||
listener_name
|
||||
.and_then(|name| name.to_str().ok())
|
||||
.map(|name| l.name == name)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.cloned()
|
||||
}
|
||||
.unwrap();
|
||||
agent_selector
|
||||
.find_listener(listener_name, &listeners)
|
||||
.await?
|
||||
};
|
||||
|
||||
info!("Handling request for listener: {}", listener.name);
|
||||
|
||||
let mut request_headers = request.headers().clone();
|
||||
// Parse request body
|
||||
let request_headers = request.headers().clone();
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!(
|
||||
|
|
@ -56,269 +80,64 @@ pub async fn agent_chat(
|
|||
);
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest =
|
||||
match serde_json::from_slice(&chat_request_bytes) {
|
||||
Ok(req) => req,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Failed to parse request body as ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
let err_msg = format!("Failed to parse request body: {}", err);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
};
|
||||
|
||||
let agent_name_map = {
|
||||
let agents = agents_list.read().await;
|
||||
let agents = agents.as_ref().unwrap();
|
||||
let mut map = std::collections::HashMap::new();
|
||||
for agent in agents.iter() {
|
||||
map.insert(agent.name.clone(), agent.clone());
|
||||
}
|
||||
map
|
||||
};
|
||||
serde_json::from_slice(&chat_request_bytes).map_err(|err| {
|
||||
warn!("Failed to parse request body as ChatCompletionsRequest: {}", err);
|
||||
AgentChatError::RequestParsing(err)
|
||||
})?;
|
||||
|
||||
// Extract trace parent for routing
|
||||
let trace_parent = request_headers
|
||||
.iter()
|
||||
.find(|(ty, _)| ty.as_str() == "traceparent")
|
||||
.find(|(key, _)| key.as_str() == "traceparent")
|
||||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
let agents_usage_preferences: Vec<ModelUsagePreference> =
|
||||
convert_agent_description_to_routing_preferences(&listener.agents.as_ref().unwrap());
|
||||
// Select appropriate agent using arch router llm model
|
||||
let selected_agent = agent_selector
|
||||
.select_agent(
|
||||
&chat_completions_request.messages,
|
||||
&listener,
|
||||
trace_parent,
|
||||
)
|
||||
.await?;
|
||||
|
||||
debug!(
|
||||
"Agents usage preferences for agent routing: {:?}",
|
||||
agents_usage_preferences
|
||||
);
|
||||
debug!("Processing agent pipeline: {}", selected_agent.name);
|
||||
|
||||
let agent_pipeline = match agents_usage_preferences.len() > 1 {
|
||||
false => {
|
||||
debug!("Only one agent available, skipping routing");
|
||||
listener.agents.as_ref().unwrap()[0].clone()
|
||||
}
|
||||
true => {
|
||||
let selected_agent = match router_service
|
||||
.determine_route(
|
||||
&chat_completions_request.messages,
|
||||
trace_parent.clone(),
|
||||
Some(agents_usage_preferences),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(route) => {
|
||||
match route {
|
||||
Some((_, agent_name)) => {
|
||||
debug!("Determined agent: {}", agent_name);
|
||||
listener
|
||||
.agents
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.find(|a| a.name == agent_name)
|
||||
.cloned()
|
||||
// selected agent must exist in the agent map
|
||||
.unwrap()
|
||||
}
|
||||
None => {
|
||||
debug!("No agent determined using routing preferences, using default agent");
|
||||
listener
|
||||
.agents
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.find(|a| a.default.unwrap_or(false))
|
||||
.cloned()
|
||||
.unwrap_or_else(|| {
|
||||
warn!(
|
||||
"No default agent found, routing request to first agent: {}",
|
||||
listener.agents.as_ref().unwrap()[0].name
|
||||
);
|
||||
listener.agents.as_ref().unwrap()[0].clone()
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to determine route: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
selected_agent
|
||||
}
|
||||
// Create agent map for pipeline processing
|
||||
let agent_map = {
|
||||
let agents = agents_list.read().await;
|
||||
let agents = agents.as_ref().unwrap();
|
||||
agent_selector.create_agent_map(agents)
|
||||
};
|
||||
|
||||
debug!("Processing agent pipeline: {}", agent_pipeline.name);
|
||||
// Process the filter chain
|
||||
let processed_messages = pipeline_processor
|
||||
.process_filter_chain(
|
||||
&chat_completions_request,
|
||||
&selected_agent,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut chat_completions_history = chat_completions_request.messages.clone();
|
||||
// Get terminal agent and send final response
|
||||
let terminal_agent_name = selected_agent.filter_chain.last().unwrap();
|
||||
let terminal_agent = agent_map.get(terminal_agent_name).unwrap();
|
||||
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
let filter_chain_without_terminal_agent =
|
||||
&agent_pipeline.filter_chain[..agent_pipeline.filter_chain.len() - 1];
|
||||
|
||||
for agent_name in filter_chain_without_terminal_agent {
|
||||
debug!("Processing agent: {}", agent_name);
|
||||
let agent = agent_name_map.get(agent_name).unwrap();
|
||||
debug!("Agent details: {:?}", agent);
|
||||
|
||||
let mut request = chat_completions_request.clone();
|
||||
request.messages = chat_completions_history.clone();
|
||||
|
||||
let request_str = serde_json::to_string(&request).unwrap();
|
||||
debug!("Sending request to agent {}", agent_name);
|
||||
|
||||
let mut agent_request_headers = request_headers.clone();
|
||||
agent_request_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(agent.name.as_str()).unwrap(),
|
||||
);
|
||||
|
||||
let response = match reqwest::Client::new()
|
||||
.post("http://localhost:11000/v1/chat/completions")
|
||||
.headers(agent_request_headers)
|
||||
.body(request_str)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
let response_bytes = match response.bytes().await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to read response bytes: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
let chat_completions_response: hermesllm::apis::openai::ChatCompletionsResponse =
|
||||
match serde_json::from_slice(&response_bytes) {
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to parse response body: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
let response_str = chat_completions_response.choices[0]
|
||||
.message
|
||||
.content
|
||||
.clone()
|
||||
.unwrap();
|
||||
|
||||
debug!("Received response from agent {}", agent_name);
|
||||
|
||||
chat_completions_history = serde_json::from_str(response_str.as_str()).unwrap_or(vec![]);
|
||||
}
|
||||
|
||||
let terminal_agent_name = agent_pipeline.filter_chain.last().unwrap();
|
||||
let terminal_agent = agent_name_map.get(terminal_agent_name).unwrap();
|
||||
debug!("Processing terminal agent: {}", terminal_agent_name);
|
||||
debug!("Terminal agent details: {:?}", terminal_agent);
|
||||
|
||||
let mut request = chat_completions_request.clone();
|
||||
request.messages = chat_completions_history.clone();
|
||||
let llm_response = pipeline_processor
|
||||
.send_terminal_request(
|
||||
&processed_messages,
|
||||
&chat_completions_request,
|
||||
terminal_agent,
|
||||
&request_headers,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let request_str = serde_json::to_string(&request).unwrap();
|
||||
debug!("Sending request to agent {}", terminal_agent_name);
|
||||
|
||||
let mut agent_request_headers = request_headers.clone();
|
||||
agent_request_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(terminal_agent.name.as_str()).unwrap(),
|
||||
);
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post("http://localhost:11000/v1/chat/completions")
|
||||
.headers(agent_request_headers)
|
||||
.body(request_str)
|
||||
.send()
|
||||
// Create streaming response
|
||||
response_handler
|
||||
.create_streaming_response(llm_response)
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
// copy over the headers from the original response
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let mut response = Response::builder();
|
||||
let headers = response.headers_mut().unwrap();
|
||||
for (header_name, header_value) in response_headers.iter() {
|
||||
headers.insert(header_name, header_value.clone());
|
||||
}
|
||||
|
||||
// channel to create async stream
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(16);
|
||||
|
||||
// Spawn a task to send data as it becomes available
|
||||
tokio::spawn(async move {
|
||||
let mut byte_stream = llm_response.bytes_stream();
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let item = match item {
|
||||
Ok(item) => item,
|
||||
Err(err) => {
|
||||
warn!("Error receiving chunk: {:?}", err);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if tx.send(item).await.is_err() {
|
||||
warn!("Receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
|
||||
let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
||||
match response.body(stream_body) {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to create response: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
Ok(internal_error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_agent_description_to_routing_preferences(
|
||||
agents: &Vec<AgentPipeline>,
|
||||
) -> Vec<ModelUsagePreference> {
|
||||
agents
|
||||
.iter()
|
||||
.map(|agent| ModelUsagePreference {
|
||||
model: agent.name.clone(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: agent.name.clone(),
|
||||
description: agent
|
||||
.description
|
||||
.as_ref()
|
||||
.unwrap_or(&"".to_string())
|
||||
.clone(),
|
||||
}],
|
||||
})
|
||||
.collect()
|
||||
.map_err(AgentChatError::from)
|
||||
}
|
||||
|
|
|
|||
296
crates/brightstaff/src/handlers/agent_selector.rs
Normal file
296
crates/brightstaff/src/handlers/agent_selector.rs
Normal file
|
|
@ -0,0 +1,296 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::configuration::{
|
||||
Agent, AgentPipeline, Listener, ModelUsagePreference, RoutingPreference,
|
||||
};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
|
||||
/// Errors that can occur during agent selection
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AgentSelectionError {
|
||||
#[error("Listener not found for name: {0}")]
|
||||
ListenerNotFound(String),
|
||||
#[error("No agents configured for listener: {0}")]
|
||||
NoAgentsConfigured(String),
|
||||
#[error("Routing service error: {0}")]
|
||||
RoutingError(String),
|
||||
#[error("Default agent not found for listener: {0}")]
|
||||
DefaultAgentNotFound(String),
|
||||
}
|
||||
|
||||
/// Service for selecting agents based on routing preferences and listener configuration
|
||||
pub struct AgentSelector {
|
||||
router_service: Arc<RouterService>,
|
||||
}
|
||||
|
||||
impl AgentSelector {
|
||||
pub fn new(router_service: Arc<RouterService>) -> Self {
|
||||
Self { router_service }
|
||||
}
|
||||
|
||||
/// Find listener by name from the request headers
|
||||
pub async fn find_listener(
|
||||
&self,
|
||||
listener_name: Option<&str>,
|
||||
listeners: &[common::configuration::Listener],
|
||||
) -> Result<Listener, AgentSelectionError> {
|
||||
let listener = listeners
|
||||
.iter()
|
||||
.find(|l| listener_name.map(|name| l.name == name).unwrap_or(false))
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
AgentSelectionError::ListenerNotFound(
|
||||
listener_name.unwrap_or("unknown").to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(listener)
|
||||
}
|
||||
|
||||
/// Create agent name to agent mapping for efficient lookup
|
||||
pub fn create_agent_map(&self, agents: &[Agent]) -> HashMap<String, Agent> {
|
||||
agents
|
||||
.iter()
|
||||
.map(|agent| (agent.name.clone(), agent.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Select appropriate agent based on routing preferences
|
||||
pub async fn select_agent(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
listener: &Listener,
|
||||
trace_parent: Option<String>,
|
||||
) -> Result<AgentPipeline, AgentSelectionError> {
|
||||
let agents = listener
|
||||
.agents
|
||||
.as_ref()
|
||||
.ok_or_else(|| AgentSelectionError::NoAgentsConfigured(listener.name.clone()))?;
|
||||
|
||||
// If only one agent, skip routing
|
||||
if agents.len() == 1 {
|
||||
debug!("Only one agent available, skipping routing");
|
||||
return Ok(agents[0].clone());
|
||||
}
|
||||
|
||||
let usage_preferences = self.convert_agent_description_to_routing_preferences(agents);
|
||||
debug!(
|
||||
"Agents usage preferences for agent routing str: {}",
|
||||
serde_json::to_string(&usage_preferences).unwrap_or_default()
|
||||
);
|
||||
|
||||
match self
|
||||
.router_service
|
||||
.determine_route(messages, trace_parent, Some(usage_preferences))
|
||||
.await
|
||||
{
|
||||
Ok(Some((_, agent_name))) => {
|
||||
debug!("Determined agent: {}", agent_name);
|
||||
let selected_agent = agents
|
||||
.iter()
|
||||
.find(|a| a.name == agent_name)
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
AgentSelectionError::RoutingError(format!(
|
||||
"Selected agent '{}' not found in listener agents",
|
||||
agent_name
|
||||
))
|
||||
})?;
|
||||
Ok(selected_agent)
|
||||
}
|
||||
Ok(None) => {
|
||||
debug!("No agent determined using routing preferences, using default agent");
|
||||
self.get_default_agent(agents, &listener.name)
|
||||
}
|
||||
Err(err) => Err(AgentSelectionError::RoutingError(err.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the default agent or the first agent if no default is specified
|
||||
fn get_default_agent(
|
||||
&self,
|
||||
agents: &[AgentPipeline],
|
||||
listener_name: &str,
|
||||
) -> Result<AgentPipeline, AgentSelectionError> {
|
||||
agents
|
||||
.iter()
|
||||
.find(|a| a.default.unwrap_or(false))
|
||||
.cloned()
|
||||
.or_else(|| {
|
||||
warn!(
|
||||
"No default agent found, routing request to first agent: {}",
|
||||
agents[0].name
|
||||
);
|
||||
Some(agents[0].clone())
|
||||
})
|
||||
.ok_or_else(|| AgentSelectionError::DefaultAgentNotFound(listener_name.to_string()))
|
||||
}
|
||||
|
||||
/// Convert agent descriptions to routing preferences
|
||||
fn convert_agent_description_to_routing_preferences(
|
||||
&self,
|
||||
agents: &[AgentPipeline],
|
||||
) -> Vec<ModelUsagePreference> {
|
||||
agents
|
||||
.iter()
|
||||
.map(|agent| ModelUsagePreference {
|
||||
model: agent.name.clone(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: agent.name.clone(),
|
||||
description: agent.description.as_ref().unwrap_or(&String::new()).clone(),
|
||||
}],
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::configuration::{AgentPipeline, Listener};
|
||||
|
||||
fn create_test_router_service() -> Arc<RouterService> {
|
||||
Arc::new(RouterService::new(
|
||||
vec![], // empty providers for testing
|
||||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"test-provider".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn create_test_agent(name: &str, description: &str, is_default: bool) -> AgentPipeline {
|
||||
AgentPipeline {
|
||||
name: name.to_string(),
|
||||
description: Some(description.to_string()),
|
||||
default: Some(is_default),
|
||||
filter_chain: vec![name.to_string()],
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_listener(name: &str, agents: Vec<AgentPipeline>) -> Listener {
|
||||
Listener {
|
||||
name: name.to_string(),
|
||||
agents: Some(agents),
|
||||
port: 8080,
|
||||
router: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_agent_struct(name: &str) -> Agent {
|
||||
Agent {
|
||||
name: name.to_string(),
|
||||
kind: "test".to_string(),
|
||||
endpoint: "http://localhost:8080".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_listener_success() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
let listener1 = create_test_listener("test-listener", vec![]);
|
||||
let listener2 = create_test_listener("other-listener", vec![]);
|
||||
let listeners = vec![listener1.clone(), listener2];
|
||||
|
||||
let result = selector
|
||||
.find_listener(Some("test-listener"), &listeners)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().name, "test-listener");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_listener_not_found() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
let listeners = vec![create_test_listener("other-listener", vec![])];
|
||||
|
||||
let result = selector
|
||||
.find_listener(Some("nonexistent"), &listeners)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
matches!(
|
||||
result.unwrap_err(),
|
||||
AgentSelectionError::ListenerNotFound(_)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_agent_map() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
let agents = vec![
|
||||
create_test_agent_struct("agent1"),
|
||||
create_test_agent_struct("agent2"),
|
||||
];
|
||||
|
||||
let agent_map = selector.create_agent_map(&agents);
|
||||
|
||||
assert_eq!(agent_map.len(), 2);
|
||||
assert!(agent_map.contains_key("agent1"));
|
||||
assert!(agent_map.contains_key("agent2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_agent_description_to_routing_preferences() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
let agents = vec![
|
||||
create_test_agent("agent1", "First agent description", true),
|
||||
create_test_agent("agent2", "Second agent description", false),
|
||||
];
|
||||
|
||||
let preferences = selector.convert_agent_description_to_routing_preferences(&agents);
|
||||
|
||||
assert_eq!(preferences.len(), 2);
|
||||
assert_eq!(preferences[0].model, "agent1");
|
||||
assert_eq!(preferences[0].routing_preferences[0].name, "agent1");
|
||||
assert_eq!(
|
||||
preferences[0].routing_preferences[0].description,
|
||||
"First agent description"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_default_agent() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
let agents = vec![
|
||||
create_test_agent("agent1", "First agent", false),
|
||||
create_test_agent("agent2", "Default agent", true),
|
||||
create_test_agent("agent3", "Third agent", false),
|
||||
];
|
||||
|
||||
let result = selector.get_default_agent(&agents, "test-listener");
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().name, "agent2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_default_agent_fallback_to_first() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
let agents = vec![
|
||||
create_test_agent("agent1", "First agent", false),
|
||||
create_test_agent("agent2", "Second agent", false),
|
||||
];
|
||||
|
||||
let result = selector.get_default_agent(&agents, "test-listener");
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().name, "agent1");
|
||||
}
|
||||
}
|
||||
143
crates/brightstaff/src/handlers/integration_tests.rs
Normal file
143
crates/brightstaff/src/handlers/integration_tests.rs
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent};
|
||||
use hyper::header::HeaderMap;
|
||||
|
||||
use crate::handlers::agent_selector::{AgentSelector, AgentSelectionError};
|
||||
use crate::handlers::pipeline_processor::{PipelineProcessor};
|
||||
use crate::handlers::response_handler::ResponseHandler;
|
||||
use crate::router::llm_router::RouterService;
|
||||
|
||||
/// Integration test that demonstrates the modular agent chat flow
|
||||
/// This test shows how the three main components work together:
|
||||
/// 1. AgentSelector - selects the appropriate agent based on routing
|
||||
/// 2. PipelineProcessor - executes the agent pipeline
|
||||
/// 3. ResponseHandler - handles response streaming
|
||||
#[cfg(test)]
|
||||
mod integration_tests {
|
||||
use super::*;
|
||||
use common::configuration::{Agent, AgentPipeline, Listener};
|
||||
|
||||
fn create_test_router_service() -> Arc<RouterService> {
|
||||
Arc::new(RouterService::new(
|
||||
vec![], // empty providers for testing
|
||||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"test-provider".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn create_test_message(role: Role, content: &str) -> Message {
|
||||
Message {
|
||||
role,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_modular_agent_chat_flow() {
|
||||
// Setup services
|
||||
let router_service = create_test_router_service();
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
let pipeline_processor = PipelineProcessor::default();
|
||||
|
||||
// Create test data
|
||||
let agents = vec![
|
||||
Agent {
|
||||
name: "filter-agent".to_string(),
|
||||
kind: "filter".to_string(),
|
||||
endpoint: "http://localhost:8081".to_string(),
|
||||
},
|
||||
Agent {
|
||||
name: "terminal-agent".to_string(),
|
||||
kind: "terminal".to_string(),
|
||||
endpoint: "http://localhost:8082".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let agent_pipeline = AgentPipeline {
|
||||
name: "test-pipeline".to_string(),
|
||||
filter_chain: vec!["filter-agent".to_string(), "terminal-agent".to_string()],
|
||||
description: Some("Test pipeline".to_string()),
|
||||
default: Some(true),
|
||||
};
|
||||
|
||||
let listener = Listener {
|
||||
name: "test-listener".to_string(),
|
||||
agents: Some(vec![agent_pipeline.clone()]),
|
||||
port: 8080,
|
||||
router: None,
|
||||
};
|
||||
|
||||
let listeners = vec![listener];
|
||||
let messages = vec![create_test_message(Role::User, "Hello world!")];
|
||||
|
||||
// Test 1: Agent Selection
|
||||
let selected_listener = agent_selector
|
||||
.find_listener(Some("test-listener"), &listeners)
|
||||
.await;
|
||||
|
||||
assert!(selected_listener.is_ok());
|
||||
let listener = selected_listener.unwrap();
|
||||
assert_eq!(listener.name, "test-listener");
|
||||
|
||||
// Test 2: Agent Map Creation
|
||||
let agent_map = agent_selector.create_agent_map(&agents);
|
||||
assert_eq!(agent_map.len(), 2);
|
||||
assert!(agent_map.contains_key("filter-agent"));
|
||||
assert!(agent_map.contains_key("terminal-agent"));
|
||||
|
||||
// Test 3: Pipeline Processing (empty filter chain for testing)
|
||||
let request = ChatCompletionsRequest {
|
||||
messages: messages.clone(),
|
||||
model: "test-model".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Create a pipeline with only terminal agent to avoid network calls
|
||||
let test_pipeline = AgentPipeline {
|
||||
name: "test-pipeline".to_string(),
|
||||
filter_chain: vec!["terminal-agent".to_string()],
|
||||
description: None,
|
||||
default: None,
|
||||
};
|
||||
|
||||
let headers = HeaderMap::new();
|
||||
let result = pipeline_processor
|
||||
.process_filter_chain(&request, &test_pipeline, &agent_map, &headers)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let processed_messages = result.unwrap();
|
||||
assert_eq!(processed_messages.len(), 1);
|
||||
|
||||
// Test 4: Error Response Creation
|
||||
let error_response = ResponseHandler::create_bad_request("Test error");
|
||||
assert_eq!(error_response.status(), hyper::StatusCode::BAD_REQUEST);
|
||||
|
||||
println!("✅ All modular components working correctly!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_error_handling_flow() {
|
||||
let router_service = create_test_router_service();
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
|
||||
// Test listener not found
|
||||
let result = agent_selector
|
||||
.find_listener(Some("nonexistent"), &[])
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), AgentSelectionError::ListenerNotFound(_)));
|
||||
|
||||
// Test error response creation
|
||||
let error_response = ResponseHandler::create_internal_error("Pipeline failed");
|
||||
assert_eq!(error_response.status(), hyper::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
println!("✅ Error handling working correctly!");
|
||||
}
|
||||
}
|
||||
|
|
@ -1,3 +1,9 @@
|
|||
pub mod chat_completions;
|
||||
pub mod models;
|
||||
pub mod agent_chat_completions;
|
||||
pub mod agent_selector;
|
||||
pub mod pipeline_processor;
|
||||
pub mod response_handler;
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests;
|
||||
|
|
|
|||
248
crates/brightstaff/src/handlers/pipeline_processor.rs
Normal file
248
crates/brightstaff/src/handlers/pipeline_processor.rs
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use common::configuration::{Agent, AgentPipeline};
|
||||
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
|
||||
use hyper::header::HeaderMap;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
/// Errors that can occur during pipeline processing
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PipelineError {
|
||||
#[error("HTTP request failed: {0}")]
|
||||
RequestFailed(#[from] reqwest::Error),
|
||||
#[error("Failed to parse response: {0}")]
|
||||
ParseError(#[from] serde_json::Error),
|
||||
#[error("Agent '{0}' not found in agent map")]
|
||||
AgentNotFound(String),
|
||||
#[error("No choices in response from agent '{0}'")]
|
||||
NoChoicesInResponse(String),
|
||||
#[error("No content in response from agent '{0}'")]
|
||||
NoContentInResponse(String),
|
||||
}
|
||||
|
||||
/// Service for processing agent pipelines
|
||||
pub struct PipelineProcessor {
|
||||
client: reqwest::Client,
|
||||
llm_endpoint: String,
|
||||
}
|
||||
|
||||
impl Default for PipelineProcessor {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
llm_endpoint: "http://localhost:11000/v1/chat/completions".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PipelineProcessor {
|
||||
pub fn new(llm_endpoint: String) -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
llm_endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
/// Process the filter chain of agents (all except the terminal agent)
|
||||
pub async fn process_filter_chain(
|
||||
&self,
|
||||
initial_request: &ChatCompletionsRequest,
|
||||
agent_pipeline: &AgentPipeline,
|
||||
agent_map: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
let mut chat_completions_history = initial_request.messages.clone();
|
||||
|
||||
let filter_chain_without_terminal =
|
||||
&agent_pipeline.filter_chain[..agent_pipeline.filter_chain.len().saturating_sub(1)];
|
||||
|
||||
for agent_name in filter_chain_without_terminal {
|
||||
debug!("Processing filter agent: {}", agent_name);
|
||||
|
||||
let agent = agent_map
|
||||
.get(agent_name)
|
||||
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
|
||||
|
||||
debug!("Agent details: {:?}", agent);
|
||||
|
||||
let response_content = self
|
||||
.send_agent_request(
|
||||
&chat_completions_history,
|
||||
initial_request,
|
||||
agent,
|
||||
request_headers,
|
||||
)
|
||||
.await?;
|
||||
|
||||
debug!("Received response from filter agent {}", agent_name);
|
||||
|
||||
// Parse the response content as new message history
|
||||
chat_completions_history = serde_json::from_str(&response_content).map_err(|err| {
|
||||
warn!(
|
||||
"Failed to parse response from agent {}, response: {}",
|
||||
agent_name, response_content
|
||||
);
|
||||
err
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(chat_completions_history)
|
||||
}
|
||||
|
||||
/// Send request to a specific agent and return the response content
|
||||
async fn send_agent_request(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
original_request: &ChatCompletionsRequest,
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<String, PipelineError> {
|
||||
let mut request = original_request.clone();
|
||||
request.messages = messages.to_vec();
|
||||
|
||||
let request_body = serde_json::to_string(&request)?;
|
||||
debug!("Sending request to agent {}", agent.name);
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&agent.name)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent.name.clone()))?,
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.llm_endpoint)
|
||||
.headers(agent_headers)
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
||||
// Parse the response as JSON to extract the content
|
||||
let response_json: serde_json::Value = serde_json::from_slice(&response_bytes)?;
|
||||
|
||||
let content = response_json
|
||||
.get("choices")
|
||||
.and_then(|choices| choices.as_array())
|
||||
.and_then(|choices| choices.first())
|
||||
.and_then(|choice| choice.get("message"))
|
||||
.and_then(|message| message.get("content"))
|
||||
.and_then(|content| content.as_str())
|
||||
.ok_or_else(|| PipelineError::NoContentInResponse(agent.name.clone()))?
|
||||
.to_string();
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
/// Send request to terminal agent and return the raw response for streaming
|
||||
pub async fn send_terminal_request(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
original_request: &ChatCompletionsRequest,
|
||||
terminal_agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<reqwest::Response, PipelineError> {
|
||||
let mut request = original_request.clone();
|
||||
request.messages = messages.to_vec();
|
||||
|
||||
let request_body = serde_json::to_string(&request)?;
|
||||
debug!("Sending request to terminal agent {}", terminal_agent.name);
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&terminal_agent.name)
|
||||
.map_err(|_| PipelineError::AgentNotFound(terminal_agent.name.clone()))?,
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.llm_endpoint)
|
||||
.headers(agent_headers)
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn create_test_message(role: Role, content: &str) -> Message {
|
||||
Message {
|
||||
role,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_pipeline(agents: Vec<&str>) -> AgentPipeline {
|
||||
AgentPipeline {
|
||||
name: "test-pipeline".to_string(),
|
||||
filter_chain: agents.iter().map(|s| s.to_string()).collect(),
|
||||
description: None,
|
||||
default: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_process_empty_filter_chain() {
|
||||
let processor = PipelineProcessor::default();
|
||||
let agent_map = HashMap::new();
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let initial_request = ChatCompletionsRequest {
|
||||
messages: vec![create_test_message(Role::User, "Hello")],
|
||||
model: "test-model".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Pipeline with only terminal agent (no filter chain)
|
||||
let pipeline = create_test_pipeline(vec!["terminal-agent"]);
|
||||
|
||||
let result = processor
|
||||
.process_filter_chain(&initial_request, &pipeline, &agent_map, &request_headers)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let messages = result.unwrap();
|
||||
assert_eq!(messages.len(), 1);
|
||||
if let MessageContent::Text(text) = &messages[0].content {
|
||||
assert_eq!(text, "Hello");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_agent_not_found_error() {
|
||||
let processor = PipelineProcessor::default();
|
||||
let agent_map = HashMap::new();
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let initial_request = ChatCompletionsRequest {
|
||||
messages: vec![create_test_message(Role::User, "Hello")],
|
||||
model: "test-model".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
|
||||
|
||||
let result = processor
|
||||
.process_filter_chain(&initial_request, &pipeline, &agent_map, &request_headers)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
matches!(result.unwrap_err(), PipelineError::AgentNotFound(_));
|
||||
}
|
||||
}
|
||||
160
crates/brightstaff/src/handlers/response_handler.rs
Normal file
160
crates/brightstaff/src/handlers/response_handler.rs
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
use bytes::Bytes;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
use hyper::{Response, StatusCode};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::warn;
|
||||
|
||||
/// Errors that can occur during response handling
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ResponseError {
|
||||
#[error("Failed to create response: {0}")]
|
||||
ResponseCreationFailed(#[from] hyper::http::Error),
|
||||
#[error("Stream error: {0}")]
|
||||
StreamError(String),
|
||||
}
|
||||
|
||||
/// Service for handling HTTP responses and streaming
|
||||
pub struct ResponseHandler;
|
||||
|
||||
impl ResponseHandler {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
/// Create a full response body from bytes
|
||||
pub fn create_full_body<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// Create an error response with a given status code and message
|
||||
pub fn create_error_response(
|
||||
status: StatusCode,
|
||||
message: &str,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let mut response = Response::new(Self::create_full_body(message.to_string()));
|
||||
*response.status_mut() = status;
|
||||
response
|
||||
}
|
||||
|
||||
/// Create a bad request response
|
||||
pub fn create_bad_request(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
Self::create_error_response(StatusCode::BAD_REQUEST, message)
|
||||
}
|
||||
|
||||
/// Create an internal server error response
|
||||
pub fn create_internal_error(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
Self::create_error_response(StatusCode::INTERNAL_SERVER_ERROR, message)
|
||||
}
|
||||
|
||||
/// Create a streaming response from a reqwest response
|
||||
pub async fn create_streaming_response(
|
||||
&self,
|
||||
llm_response: reqwest::Response,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ResponseError> {
|
||||
// Copy headers from the original response
|
||||
let response_headers = llm_response.headers();
|
||||
let mut response_builder = Response::builder();
|
||||
|
||||
let headers = response_builder.headers_mut().ok_or_else(|| {
|
||||
ResponseError::StreamError("Failed to get mutable headers".to_string())
|
||||
})?;
|
||||
|
||||
for (header_name, header_value) in response_headers.iter() {
|
||||
headers.insert(header_name, header_value.clone());
|
||||
}
|
||||
|
||||
// Create channel for async streaming
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(16);
|
||||
|
||||
// Spawn task to stream data
|
||||
tokio::spawn(async move {
|
||||
let mut byte_stream = llm_response.bytes_stream();
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let chunk = match item {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => {
|
||||
warn!("Error receiving chunk: {:?}", err);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if tx.send(chunk).await.is_err() {
|
||||
warn!("Receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
||||
response_builder
|
||||
.body(stream_body)
|
||||
.map_err(ResponseError::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ResponseHandler {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use hyper::StatusCode;
|
||||
|
||||
#[test]
|
||||
fn test_create_bad_request() {
|
||||
let response = ResponseHandler::create_bad_request("Invalid request");
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_internal_error() {
|
||||
let response = ResponseHandler::create_internal_error("Server error");
|
||||
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_error_response() {
|
||||
let response =
|
||||
ResponseHandler::create_error_response(StatusCode::NOT_FOUND, "Resource not found");
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_streaming_response_with_mock() {
|
||||
use mockito::Server;
|
||||
|
||||
let mut server = Server::new_async().await;
|
||||
let mock = server
|
||||
.mock("GET", "/test")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "text/plain")
|
||||
.with_body("streaming response")
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let llm_response = client.get(&(server.url() + "/test")).send().await.unwrap();
|
||||
|
||||
let handler = ResponseHandler::new();
|
||||
let result = handler.create_streaming_response(llm_response).await;
|
||||
|
||||
mock.assert_async().await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let response = result.unwrap();
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert!(response.headers().contains_key("content-type"));
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue