This commit is contained in:
Adil Hafeez 2025-09-11 15:55:25 -07:00
parent 32838584cf
commit 093834bb05
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
14 changed files with 623 additions and 48 deletions

View file

@ -0,0 +1,278 @@
use std::sync::Arc;
use bytes::Bytes;
use common::api::open_ai::{ChatCompletionsResponse, Choice};
use common::configuration::ModelUsagePreference;
use common::consts::ARCH_PROVIDER_HINT_HEADER;
use hermesllm::apis::openai::ChatCompletionsRequest;
use hermesllm::apis::{Role, Usage};
use hermesllm::clients::SupportedAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}
pub async fn agent_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// find listener that is running at port 8001 for agents
let listener = {
let listeners = listeners.read().await;
listeners.iter().find(|l| l.port == 8001).cloned()
}
.unwrap();
let request_path = request.uri().path().to_string();
let mut request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
debug!(
"Received request body (raw utf8): {}",
String::from_utf8_lossy(&chat_request_bytes)
);
let chat_completions_request: ChatCompletionsRequest =
match serde_json::from_slice(&chat_request_bytes) {
Ok(req) => req,
Err(err) => {
warn!(
"Failed to parse request body as ChatCompletionsRequest: {}",
err
);
let err_msg = format!("Failed to parse request body: {}", err);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
};
let agent_name_map = {
let agents = agents_list.read().await;
let agents = agents.as_ref().unwrap();
let mut map = std::collections::HashMap::new();
for agent in agents.iter() {
map.insert(agent.name.clone(), agent.clone());
}
map
};
// find agent to answer the request
let agent_pipeline = listener.agents.as_ref().unwrap()[0].clone(); // for now, just take the first agent pipeline
// process agent pipeline
debug!("Processing agent pipeline: {}", agent_pipeline.name);
let mut chat_completions_history = chat_completions_request.messages.clone();
let mut last_response: Option<String> = None;
for agent_name in agent_pipeline.filter_chain {
debug!("Processing agent: {}", agent_name);
let agent = agent_name_map.get(&agent_name).unwrap();
debug!("Agent details: {:?}", agent);
let path = format!(
"{}/v1/chat/completions",
agent.endpoint.trim_end_matches('/')
);
let mut request = chat_completions_request.clone();
request.messages = chat_completions_history.clone();
let request_str = serde_json::to_string(&request).unwrap();
debug!("Sending request to agent {}: {}", agent_name, request_str);
let response = match reqwest::Client::new()
.post(path)
.body(request_str)
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let response_bytes = match response.bytes().await {
Ok(bytes) => bytes,
Err(err) => {
let err_msg = format!("Failed to read response bytes: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let chat_completions_response: hermesllm::apis::openai::ChatCompletionsResponse =
match serde_json::from_slice(&response_bytes) {
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to parse response body: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let response_str = chat_completions_response.choices[0]
.message
.content
.clone()
.unwrap();
debug!(
"Received response from agent {}: {}",
agent_name, response_str
);
chat_completions_history = serde_json::from_str(response_str.as_str()).unwrap_or(vec![]);
// chat_completions_history.append(&mut vec![hermesllm::apis::openai::Message {
// role: hermesllm::apis::openai::Role::Assistant,
// content: hermesllm::apis::openai::MessageContent::Text(response_str),
// name: Some(agent_name.clone()),
// tool_calls: None,
// tool_call_id: None,
// }]);
}
let last_response: Option<String> = match chat_completions_history.last() {
Some(msg) => Some(msg.content.clone().to_string()),
None => None,
};
let chat_completion_response: hermesllm::apis::openai::ChatCompletionsResponse =
hermesllm::apis::openai::ChatCompletionsResponse {
model: "arch-agent".to_string(),
choices: vec![hermesllm::apis::openai::Choice {
index: 0,
finish_reason: None,
message: {
hermesllm::apis::openai::ResponseMessage {
role: hermesllm::apis::openai::Role::Assistant,
content: last_response,
refusal: None,
annotations: None,
audio: None,
function_call: None,
tool_calls: None,
}
},
logprobs: None,
}],
usage: hermesllm::apis::openai::Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens_details: None,
completion_tokens_details: None,
},
id: "00".to_string(),
object: "chat.completion".to_string(),
created: 0,
system_fingerprint: None,
service_tier: None,
};
let response_body = serde_json::to_string(&chat_completion_response).unwrap();
return Ok(Response::new(full(response_body)));
// request_headers.insert(
// ARCH_PROVIDER_HINT_HEADER,
// header::HeaderValue::from_str(&model_name).unwrap(),
// );
// if let Some(trace_parent) = trace_parent {
// request_headers.insert(
// header::HeaderName::from_static("traceparent"),
// header::HeaderValue::from_str(&trace_parent).unwrap(),
// );
// }
// // remove content-length header if it exists
// request_headers.remove(header::CONTENT_LENGTH);
// let llm_response = match reqwest::Client::new()
// .post(full_qualified_llm_provider_url)
// .headers(request_headers)
// .body(client_request_bytes_for_upstream)
// .send()
// .await
// {
// Ok(res) => res,
// Err(err) => {
// let err_msg = format!("Failed to send request: {}", err);
// let mut internal_error = Response::new(full(err_msg));
// *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
// return Ok(internal_error);
// }
// };
// // copy over the headers from the original response
// let response_headers = llm_response.headers().clone();
// let mut response = Response::builder();
// let headers = response.headers_mut().unwrap();
// for (header_name, header_value) in response_headers.iter() {
// headers.insert(header_name, header_value.clone());
// }
// // channel to create async stream
// let (tx, rx) = mpsc::channel::<Bytes>(16);
// // Spawn a task to send data as it becomes available
// tokio::spawn(async move {
// let mut byte_stream = llm_response.bytes_stream();
// while let Some(item) = byte_stream.next().await {
// let item = match item {
// Ok(item) => item,
// Err(err) => {
// warn!("Error receiving chunk: {:?}", err);
// break;
// }
// };
// if tx.send(item).await.is_err() {
// warn!("Receiver dropped");
// break;
// }
// }
// });
// let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
// let stream_body = BoxBody::new(StreamBody::new(stream));
// match response.body(stream_body) {
// Ok(response) => Ok(response),
// Err(err) => {
// let err_msg = format!("Failed to create response: {}", err);
// let mut internal_error = Response::new(full(err_msg));
// *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
// Ok(internal_error)
// }
// }
}

View file

@ -1,2 +1,3 @@
pub mod chat_completions;
pub mod models;
pub mod agent_chat_completions;

View file

@ -1,3 +1,4 @@
use brightstaff::handlers::agent_chat_completions::agent_chat;
use brightstaff::handlers::chat_completions::chat;
use brightstaff::handlers::models::list_models;
use brightstaff::router::llm_router::RouterService;
@ -62,6 +63,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let arch_config = Arc::new(config);
let llm_providers = Arc::new(RwLock::new(arch_config.llm_providers.clone()));
let agents_list = Arc::new(RwLock::new(arch_config.agents.clone()));
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
debug!(
"arch_config: {:?}",
@ -103,12 +106,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
let agents_list = agents_list.clone();
let listeners = listeners.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
let parent_cx = extract_context_from_request(&req);
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
let agents_list = agents_list.clone();
let listeners = listeners.clone();
async move {
match (req.method(), req.uri().path()) {
@ -118,6 +125,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.with_context(parent_cx)
.await
}
(&Method::POST, "/agents/v1/chat/completions") => {
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
agent_chat(req, router_service, fully_qualified_url, agents_list, listeners)
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await),
(&Method::OPTIONS, "/v1/models") => {
let mut response = Response::new(empty());
@ -143,6 +156,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(response)
}
_ => {
debug!("No route for {} {}", req.method(), req.uri().path());
let mut not_found = Response::new(empty());
*not_found.status_mut() = StatusCode::NOT_FOUND;
Ok(not_found)

View file

@ -13,6 +13,28 @@ pub struct Routing {
pub model: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Agent {
pub name: String,
pub kind: String,
pub endpoint: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentPipeline {
pub name: String,
pub description: Option<String>,
pub filter_chain: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Listener {
pub name: String,
pub router: Option<String>,
pub agents: Option<Vec<AgentPipeline>>,
pub port: u16,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
@ -27,6 +49,8 @@ pub struct Configuration {
pub tracing: Option<Tracing>,
pub mode: Option<GatewayMode>,
pub routing: Option<Routing>,
pub agents: Option<Vec<Agent>>,
pub listeners: Vec<Listener>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]