mirror of
https://github.com/katanemo/plano.git
synced 2026-06-23 15:38:07 +02:00
add streaming
This commit is contained in:
parent
4588787427
commit
08471d8adf
5 changed files with 274 additions and 59 deletions
|
|
@ -1,19 +1,14 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::api::open_ai::{ChatCompletionsResponse, Choice};
|
||||
use common::configuration::{AgentPipeline, ModelUsagePreference, RoutingPreference};
|
||||
use common::consts::{ARCH_PROVIDER_HINT_HEADER, ARCH_UPSTREAM_HOST_HEADER};
|
||||
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::apis::{Role, Usage};
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode, Uri};
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
|
|
@ -30,7 +25,7 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
|||
pub async fn agent_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
_: String,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
|
@ -52,7 +47,6 @@ pub async fn agent_chat(
|
|||
|
||||
info!("Handling request for listener: {}", listener.name);
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
let mut request_headers = request.headers().clone();
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
|
|
@ -163,9 +157,12 @@ pub async fn agent_chat(
|
|||
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
for agent_name in agent_pipeline.filter_chain {
|
||||
let filter_chain_without_terminal_agent =
|
||||
&agent_pipeline.filter_chain[..agent_pipeline.filter_chain.len() - 1];
|
||||
|
||||
for agent_name in filter_chain_without_terminal_agent {
|
||||
debug!("Processing agent: {}", agent_name);
|
||||
let agent = agent_name_map.get(&agent_name).unwrap();
|
||||
let agent = agent_name_map.get(agent_name).unwrap();
|
||||
debug!("Agent details: {:?}", agent);
|
||||
|
||||
let mut request = chat_completions_request.clone();
|
||||
|
|
@ -223,41 +220,88 @@ pub async fn agent_chat(
|
|||
.clone()
|
||||
.unwrap();
|
||||
|
||||
debug!(
|
||||
"Received response from agent {}",
|
||||
agent_name
|
||||
);
|
||||
debug!("Received response from agent {}", agent_name);
|
||||
|
||||
chat_completions_history = serde_json::from_str(response_str.as_str()).unwrap_or(vec![]);
|
||||
}
|
||||
|
||||
let last_response: Option<String> = match chat_completions_history.last() {
|
||||
Some(msg) => Some(msg.content.clone().to_string()),
|
||||
None => None,
|
||||
let terminal_agent_name = agent_pipeline.filter_chain.last().unwrap();
|
||||
let terminal_agent = agent_name_map.get(terminal_agent_name).unwrap();
|
||||
debug!("Processing terminal agent: {}", terminal_agent_name);
|
||||
debug!("Terminal agent details: {:?}", terminal_agent);
|
||||
|
||||
let mut request = chat_completions_request.clone();
|
||||
request.messages = chat_completions_history.clone();
|
||||
|
||||
let request_str = serde_json::to_string(&request).unwrap();
|
||||
debug!("Sending request to agent {}", terminal_agent_name);
|
||||
|
||||
let mut agent_request_headers = request_headers.clone();
|
||||
agent_request_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(terminal_agent.name.as_str()).unwrap(),
|
||||
);
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post("http://localhost:11000/v1/chat/completions")
|
||||
.headers(agent_request_headers)
|
||||
.body(request_str)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
let chat_completion_response: hermesllm::apis::openai::ChatCompletionsResponse =
|
||||
hermesllm::apis::openai::ChatCompletionsResponse {
|
||||
model: "arch-agent".to_string(),
|
||||
choices: vec![hermesllm::apis::openai::Choice {
|
||||
message: {
|
||||
hermesllm::apis::openai::ResponseMessage {
|
||||
role: hermesllm::apis::openai::Role::Assistant,
|
||||
content: last_response,
|
||||
..Default::default()
|
||||
}
|
||||
},
|
||||
..Default::default()
|
||||
}],
|
||||
usage: hermesllm::apis::openai::Usage {
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
// copy over the headers from the original response
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let mut response = Response::builder();
|
||||
let headers = response.headers_mut().unwrap();
|
||||
for (header_name, header_value) in response_headers.iter() {
|
||||
headers.insert(header_name, header_value.clone());
|
||||
}
|
||||
|
||||
let response_body = serde_json::to_string(&chat_completion_response).unwrap();
|
||||
// channel to create async stream
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(16);
|
||||
|
||||
return Ok(Response::new(full(response_body)));
|
||||
// Spawn a task to send data as it becomes available
|
||||
tokio::spawn(async move {
|
||||
let mut byte_stream = llm_response.bytes_stream();
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let item = match item {
|
||||
Ok(item) => item,
|
||||
Err(err) => {
|
||||
warn!("Error receiving chunk: {:?}", err);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if tx.send(item).await.is_err() {
|
||||
warn!("Receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
|
||||
let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
||||
match response.body(stream_body) {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to create response: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
Ok(internal_error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_agent_description_to_routing_preferences(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue