add streaming

This commit is contained in:
Adil Hafeez 2025-09-17 09:39:10 -07:00
parent 4588787427
commit 08471d8adf
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
5 changed files with 274 additions and 59 deletions

View file

@ -1,19 +1,14 @@
use std::sync::Arc;
use bytes::Bytes;
use common::api::open_ai::{ChatCompletionsResponse, Choice};
use common::configuration::{AgentPipeline, ModelUsagePreference, RoutingPreference};
use common::consts::{ARCH_PROVIDER_HINT_HEADER, ARCH_UPSTREAM_HOST_HEADER};
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use hermesllm::apis::openai::ChatCompletionsRequest;
use hermesllm::apis::{Role, Usage};
use hermesllm::clients::SupportedAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::header::{self};
use hyper::{Request, Response, StatusCode, Uri};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use hyper::{Request, Response, StatusCode};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
@ -30,7 +25,7 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
pub async fn agent_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
_: String,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
@ -52,7 +47,6 @@ pub async fn agent_chat(
info!("Handling request for listener: {}", listener.name);
let request_path = request.uri().path().to_string();
let mut request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
@ -163,9 +157,12 @@ pub async fn agent_chat(
request_headers.remove(header::CONTENT_LENGTH);
for agent_name in agent_pipeline.filter_chain {
let filter_chain_without_terminal_agent =
&agent_pipeline.filter_chain[..agent_pipeline.filter_chain.len() - 1];
for agent_name in filter_chain_without_terminal_agent {
debug!("Processing agent: {}", agent_name);
let agent = agent_name_map.get(&agent_name).unwrap();
let agent = agent_name_map.get(agent_name).unwrap();
debug!("Agent details: {:?}", agent);
let mut request = chat_completions_request.clone();
@ -223,41 +220,88 @@ pub async fn agent_chat(
.clone()
.unwrap();
debug!(
"Received response from agent {}",
agent_name
);
debug!("Received response from agent {}", agent_name);
chat_completions_history = serde_json::from_str(response_str.as_str()).unwrap_or(vec![]);
}
let last_response: Option<String> = match chat_completions_history.last() {
Some(msg) => Some(msg.content.clone().to_string()),
None => None,
let terminal_agent_name = agent_pipeline.filter_chain.last().unwrap();
let terminal_agent = agent_name_map.get(terminal_agent_name).unwrap();
debug!("Processing terminal agent: {}", terminal_agent_name);
debug!("Terminal agent details: {:?}", terminal_agent);
let mut request = chat_completions_request.clone();
request.messages = chat_completions_history.clone();
let request_str = serde_json::to_string(&request).unwrap();
debug!("Sending request to agent {}", terminal_agent_name);
let mut agent_request_headers = request_headers.clone();
agent_request_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(terminal_agent.name.as_str()).unwrap(),
);
let llm_response = match reqwest::Client::new()
.post("http://localhost:11000/v1/chat/completions")
.headers(agent_request_headers)
.body(request_str)
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let chat_completion_response: hermesllm::apis::openai::ChatCompletionsResponse =
hermesllm::apis::openai::ChatCompletionsResponse {
model: "arch-agent".to_string(),
choices: vec![hermesllm::apis::openai::Choice {
message: {
hermesllm::apis::openai::ResponseMessage {
role: hermesllm::apis::openai::Role::Assistant,
content: last_response,
..Default::default()
}
},
..Default::default()
}],
usage: hermesllm::apis::openai::Usage {
..Default::default()
},
..Default::default()
};
// copy over the headers from the original response
let response_headers = llm_response.headers().clone();
let mut response = Response::builder();
let headers = response.headers_mut().unwrap();
for (header_name, header_value) in response_headers.iter() {
headers.insert(header_name, header_value.clone());
}
let response_body = serde_json::to_string(&chat_completion_response).unwrap();
// channel to create async stream
let (tx, rx) = mpsc::channel::<Bytes>(16);
return Ok(Response::new(full(response_body)));
// Spawn a task to send data as it becomes available
tokio::spawn(async move {
let mut byte_stream = llm_response.bytes_stream();
while let Some(item) = byte_stream.next().await {
let item = match item {
Ok(item) => item,
Err(err) => {
warn!("Error receiving chunk: {:?}", err);
break;
}
};
if tx.send(item).await.is_err() {
warn!("Receiver dropped");
break;
}
}
});
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
match response.body(stream_body) {
Ok(response) => Ok(response),
Err(err) => {
let err_msg = format!("Failed to create response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Ok(internal_error)
}
}
}
fn convert_agent_description_to_routing_preferences(