mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
cargo fmt brightstaff
This commit is contained in:
parent
2229f0d4d4
commit
0c6600ac47
6 changed files with 74 additions and 56 deletions
|
|
@ -7,10 +7,10 @@ use http_body_util::BodyExt;
|
|||
use hyper::{Request, Response};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use super::agent_selector::{AgentSelector, AgentSelectionError};
|
||||
use super::pipeline_processor::{PipelineProcessor, PipelineError};
|
||||
use super::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
use super::response_handler::ResponseHandler;
|
||||
use crate::router::llm_router::RouterService;
|
||||
|
||||
/// Main errors for agent chat completions
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -38,7 +38,10 @@ pub async fn agent_chat(
|
|||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
warn!("Agent chat error: {}", err);
|
||||
Ok(ResponseHandler::create_internal_error(&format!("Internal error: {}", err)))
|
||||
Ok(ResponseHandler::create_internal_error(&format!(
|
||||
"Internal error: {}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -81,7 +84,10 @@ async fn handle_agent_chat(
|
|||
|
||||
let chat_completions_request: ChatCompletionsRequest =
|
||||
serde_json::from_slice(&chat_request_bytes).map_err(|err| {
|
||||
warn!("Failed to parse request body as ChatCompletionsRequest: {}", err);
|
||||
warn!(
|
||||
"Failed to parse request body as ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
AgentChatError::RequestParsing(err)
|
||||
})?;
|
||||
|
||||
|
|
@ -93,11 +99,7 @@ async fn handle_agent_chat(
|
|||
|
||||
// Select appropriate agent using arch router llm model
|
||||
let selected_agent = agent_selector
|
||||
.select_agent(
|
||||
&chat_completions_request.messages,
|
||||
&listener,
|
||||
trace_parent,
|
||||
)
|
||||
.select_agent(&chat_completions_request.messages, &listener, trace_parent)
|
||||
.await?;
|
||||
|
||||
debug!("Processing agent pipeline: {}", selected_agent.name);
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{ModelAlias, ModelUsagePreference};
|
||||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||
|
|
@ -11,6 +9,8 @@ use http_body_util::{BodyExt, Full, StreamBody};
|
|||
use hyper::body::Frame;
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
|
|
@ -30,14 +30,19 @@ pub async fn chat(
|
|||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
let mut request_headers = request.headers().clone();
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
|
||||
debug!(
|
||||
"Received request body (raw utf8): {}",
|
||||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
);
|
||||
|
||||
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
|
||||
let mut client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
&SupportedAPIs::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse request as ProviderRequestType: {}", err);
|
||||
|
|
@ -77,7 +82,10 @@ pub async fn chat(
|
|||
|
||||
// Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original)
|
||||
let chat_completions_request_for_arch_router: ChatCompletionsRequest =
|
||||
match ProviderRequestType::try_from((client_request, &SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) {
|
||||
match ProviderRequestType::try_from((
|
||||
client_request,
|
||||
&SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions),
|
||||
)) {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
|
||||
Ok(ProviderRequestType::MessagesRequest(_)) => {
|
||||
// This should not happen after conversion to OpenAI format
|
||||
|
|
@ -86,9 +94,12 @@ pub async fn chat(
|
|||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
},
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
|
||||
warn!(
|
||||
"Failed to convert request to ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
let err_msg = format!("Failed to convert request: {}", err);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
|
|
@ -106,24 +117,22 @@ pub async fn chat(
|
|||
.find(|(ty, _)| ty.as_str() == "traceparent")
|
||||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
let usage_preferences_str: Option<String> =
|
||||
routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("archgw_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("archgw_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
|
||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
|
||||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok());
|
||||
|
||||
let latest_message_for_log =
|
||||
chat_completions_request_for_arch_router
|
||||
.messages
|
||||
.last()
|
||||
.map_or("None".to_string(), |msg| {
|
||||
msg.content.to_string().replace('\n', "\\n")
|
||||
});
|
||||
let latest_message_for_log = chat_completions_request_for_arch_router
|
||||
.messages
|
||||
.last()
|
||||
.map_or("None".to_string(), |msg| {
|
||||
msg.content.to_string().replace('\n', "\\n")
|
||||
});
|
||||
|
||||
const MAX_MESSAGE_LENGTH: usize = 50;
|
||||
let latest_message_for_log = if latest_message_for_log.len() > MAX_MESSAGE_LENGTH {
|
||||
|
|
@ -152,12 +161,11 @@ pub async fn chat(
|
|||
Ok(route) => match route {
|
||||
Some((_, model_name)) => model_name,
|
||||
None => {
|
||||
debug!(
|
||||
debug!(
|
||||
"No route determined, using default model from request: {}",
|
||||
chat_completions_request_for_arch_router.model
|
||||
);
|
||||
chat_completions_request_for_arch_router.model.clone()
|
||||
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use hyper::header::HeaderMap;
|
||||
|
||||
use crate::handlers::agent_selector::{AgentSelector, AgentSelectionError};
|
||||
use crate::handlers::pipeline_processor::{PipelineProcessor};
|
||||
use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
use crate::handlers::pipeline_processor::PipelineProcessor;
|
||||
use crate::handlers::response_handler::ResponseHandler;
|
||||
use crate::router::llm_router::RouterService;
|
||||
|
||||
|
|
@ -127,16 +127,20 @@ mod integration_tests {
|
|||
let agent_selector = AgentSelector::new(router_service);
|
||||
|
||||
// Test listener not found
|
||||
let result = agent_selector
|
||||
.find_listener(Some("nonexistent"), &[])
|
||||
.await;
|
||||
let result = agent_selector.find_listener(Some("nonexistent"), &[]).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), AgentSelectionError::ListenerNotFound(_)));
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
AgentSelectionError::ListenerNotFound(_)
|
||||
));
|
||||
|
||||
// Test error response creation
|
||||
let error_response = ResponseHandler::create_internal_error("Pipeline failed");
|
||||
assert_eq!(error_response.status(), hyper::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
assert_eq!(
|
||||
error_response.status(),
|
||||
hyper::StatusCode::INTERNAL_SERVER_ERROR
|
||||
);
|
||||
|
||||
println!("✅ Error handling working correctly!");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
pub mod chat_completions;
|
||||
pub mod models;
|
||||
pub mod agent_chat_completions;
|
||||
pub mod agent_selector;
|
||||
pub mod chat_completions;
|
||||
pub mod models;
|
||||
pub mod pipeline_processor;
|
||||
pub mod response_handler;
|
||||
|
||||
|
|
|
|||
|
|
@ -71,8 +71,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
&serde_json::to_string(arch_config.as_ref()).unwrap()
|
||||
);
|
||||
|
||||
let llm_provider_url = env::var("LLM_PROVIDER_ENDPOINT")
|
||||
.unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
let llm_provider_url =
|
||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
||||
info!("llm provider url: {}", llm_provider_url);
|
||||
info!("listening on http://{}", bind_address);
|
||||
|
|
@ -99,7 +99,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let model_aliases = Arc::new(arch_config.model_aliases.clone());
|
||||
|
||||
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
let peer_addr = stream.peer_addr()?;
|
||||
|
|
@ -113,7 +112,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let agents_list = agents_list.clone();
|
||||
let listeners = listeners.clone();
|
||||
let service = service_fn(move |req| {
|
||||
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
|
|
@ -125,16 +123,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, req.uri().path());
|
||||
chat(req, router_service, fully_qualified_url, model_aliases)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::POST, "/agents/v1/chat/completions") => {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
|
||||
agent_chat(req, router_service, fully_qualified_url, agents_list, listeners)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, req.uri().path());
|
||||
agent_chat(
|
||||
req,
|
||||
router_service,
|
||||
fully_qualified_url,
|
||||
agents_list,
|
||||
listeners,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await),
|
||||
(&Method::OPTIONS, "/v1/models") => {
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use common::{
|
||||
configuration::{ModelUsagePreference, RoutingPreference},
|
||||
};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, MessageContent, Message, Role};
|
||||
use common::configuration::{ModelUsagePreference, RoutingPreference};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue