mirror of
https://github.com/katanemo/plano.git
synced 2026-06-29 15:49:40 +02:00
fix more
This commit is contained in:
parent
32838584cf
commit
093834bb05
14 changed files with 623 additions and 48 deletions
278
crates/brightstaff/src/handlers/agent_chat_completions.rs
Normal file
278
crates/brightstaff/src/handlers/agent_chat_completions.rs
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::api::open_ai::{ChatCompletionsResponse, Choice};
|
||||
use common::configuration::ModelUsagePreference;
|
||||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::apis::{Role, Usage};
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
|
||||
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub async fn agent_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
// find listener that is running at port 8001 for agents
|
||||
let listener = {
|
||||
let listeners = listeners.read().await;
|
||||
listeners.iter().find(|l| l.port == 8001).cloned()
|
||||
}
|
||||
.unwrap();
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
let mut request_headers = request.headers().clone();
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!(
|
||||
"Received request body (raw utf8): {}",
|
||||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
);
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest =
|
||||
match serde_json::from_slice(&chat_request_bytes) {
|
||||
Ok(req) => req,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Failed to parse request body as ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
let err_msg = format!("Failed to parse request body: {}", err);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
};
|
||||
|
||||
let agent_name_map = {
|
||||
let agents = agents_list.read().await;
|
||||
let agents = agents.as_ref().unwrap();
|
||||
let mut map = std::collections::HashMap::new();
|
||||
for agent in agents.iter() {
|
||||
map.insert(agent.name.clone(), agent.clone());
|
||||
}
|
||||
map
|
||||
};
|
||||
|
||||
// find agent to answer the request
|
||||
let agent_pipeline = listener.agents.as_ref().unwrap()[0].clone(); // for now, just take the first agent pipeline
|
||||
|
||||
// process agent pipeline
|
||||
|
||||
debug!("Processing agent pipeline: {}", agent_pipeline.name);
|
||||
|
||||
let mut chat_completions_history = chat_completions_request.messages.clone();
|
||||
let mut last_response: Option<String> = None;
|
||||
|
||||
for agent_name in agent_pipeline.filter_chain {
|
||||
debug!("Processing agent: {}", agent_name);
|
||||
let agent = agent_name_map.get(&agent_name).unwrap();
|
||||
debug!("Agent details: {:?}", agent);
|
||||
|
||||
let path = format!(
|
||||
"{}/v1/chat/completions",
|
||||
agent.endpoint.trim_end_matches('/')
|
||||
);
|
||||
|
||||
let mut request = chat_completions_request.clone();
|
||||
request.messages = chat_completions_history.clone();
|
||||
|
||||
let request_str = serde_json::to_string(&request).unwrap();
|
||||
debug!("Sending request to agent {}: {}", agent_name, request_str);
|
||||
|
||||
let response = match reqwest::Client::new()
|
||||
.post(path)
|
||||
.body(request_str)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
let response_bytes = match response.bytes().await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to read response bytes: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
let chat_completions_response: hermesllm::apis::openai::ChatCompletionsResponse =
|
||||
match serde_json::from_slice(&response_bytes) {
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to parse response body: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
let response_str = chat_completions_response.choices[0]
|
||||
.message
|
||||
.content
|
||||
.clone()
|
||||
.unwrap();
|
||||
|
||||
debug!(
|
||||
"Received response from agent {}: {}",
|
||||
agent_name, response_str
|
||||
);
|
||||
|
||||
chat_completions_history = serde_json::from_str(response_str.as_str()).unwrap_or(vec![]);
|
||||
|
||||
// chat_completions_history.append(&mut vec![hermesllm::apis::openai::Message {
|
||||
// role: hermesllm::apis::openai::Role::Assistant,
|
||||
// content: hermesllm::apis::openai::MessageContent::Text(response_str),
|
||||
// name: Some(agent_name.clone()),
|
||||
// tool_calls: None,
|
||||
// tool_call_id: None,
|
||||
// }]);
|
||||
}
|
||||
|
||||
let last_response: Option<String> = match chat_completions_history.last() {
|
||||
Some(msg) => Some(msg.content.clone().to_string()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let chat_completion_response: hermesllm::apis::openai::ChatCompletionsResponse =
|
||||
hermesllm::apis::openai::ChatCompletionsResponse {
|
||||
model: "arch-agent".to_string(),
|
||||
choices: vec![hermesllm::apis::openai::Choice {
|
||||
index: 0,
|
||||
finish_reason: None,
|
||||
message: {
|
||||
hermesllm::apis::openai::ResponseMessage {
|
||||
role: hermesllm::apis::openai::Role::Assistant,
|
||||
content: last_response,
|
||||
refusal: None,
|
||||
annotations: None,
|
||||
audio: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
},
|
||||
logprobs: None,
|
||||
}],
|
||||
usage: hermesllm::apis::openai::Usage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
prompt_tokens_details: None,
|
||||
completion_tokens_details: None,
|
||||
},
|
||||
id: "00".to_string(),
|
||||
object: "chat.completion".to_string(),
|
||||
created: 0,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
};
|
||||
|
||||
let response_body = serde_json::to_string(&chat_completion_response).unwrap();
|
||||
|
||||
return Ok(Response::new(full(response_body)));
|
||||
|
||||
// request_headers.insert(
|
||||
// ARCH_PROVIDER_HINT_HEADER,
|
||||
// header::HeaderValue::from_str(&model_name).unwrap(),
|
||||
// );
|
||||
|
||||
// if let Some(trace_parent) = trace_parent {
|
||||
// request_headers.insert(
|
||||
// header::HeaderName::from_static("traceparent"),
|
||||
// header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
// );
|
||||
// }
|
||||
// // remove content-length header if it exists
|
||||
// request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
// let llm_response = match reqwest::Client::new()
|
||||
// .post(full_qualified_llm_provider_url)
|
||||
// .headers(request_headers)
|
||||
// .body(client_request_bytes_for_upstream)
|
||||
// .send()
|
||||
// .await
|
||||
// {
|
||||
// Ok(res) => res,
|
||||
// Err(err) => {
|
||||
// let err_msg = format!("Failed to send request: {}", err);
|
||||
// let mut internal_error = Response::new(full(err_msg));
|
||||
// *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
// return Ok(internal_error);
|
||||
// }
|
||||
// };
|
||||
|
||||
// // copy over the headers from the original response
|
||||
// let response_headers = llm_response.headers().clone();
|
||||
// let mut response = Response::builder();
|
||||
// let headers = response.headers_mut().unwrap();
|
||||
// for (header_name, header_value) in response_headers.iter() {
|
||||
// headers.insert(header_name, header_value.clone());
|
||||
// }
|
||||
|
||||
// // channel to create async stream
|
||||
// let (tx, rx) = mpsc::channel::<Bytes>(16);
|
||||
|
||||
// // Spawn a task to send data as it becomes available
|
||||
// tokio::spawn(async move {
|
||||
// let mut byte_stream = llm_response.bytes_stream();
|
||||
|
||||
// while let Some(item) = byte_stream.next().await {
|
||||
// let item = match item {
|
||||
// Ok(item) => item,
|
||||
// Err(err) => {
|
||||
// warn!("Error receiving chunk: {:?}", err);
|
||||
// break;
|
||||
// }
|
||||
// };
|
||||
|
||||
// if tx.send(item).await.is_err() {
|
||||
// warn!("Receiver dropped");
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
|
||||
// let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
|
||||
// let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
||||
// match response.body(stream_body) {
|
||||
// Ok(response) => Ok(response),
|
||||
// Err(err) => {
|
||||
// let err_msg = format!("Failed to create response: {}", err);
|
||||
// let mut internal_error = Response::new(full(err_msg));
|
||||
// *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
// Ok(internal_error)
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
|
@ -1,2 +1,3 @@
|
|||
pub mod chat_completions;
|
||||
pub mod models;
|
||||
pub mod agent_chat_completions;
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use brightstaff::handlers::agent_chat_completions::agent_chat;
|
||||
use brightstaff::handlers::chat_completions::chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
|
|
@ -62,6 +63,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let arch_config = Arc::new(config);
|
||||
|
||||
let llm_providers = Arc::new(RwLock::new(arch_config.llm_providers.clone()));
|
||||
let agents_list = Arc::new(RwLock::new(arch_config.agents.clone()));
|
||||
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
|
||||
|
||||
debug!(
|
||||
"arch_config: {:?}",
|
||||
|
|
@ -103,12 +106,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let llm_provider_url = llm_provider_url.clone();
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let agents_list = agents_list.clone();
|
||||
let listeners = listeners.clone();
|
||||
let service = service_fn(move |req| {
|
||||
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
let llm_providers = llm_providers.clone();
|
||||
let agents_list = agents_list.clone();
|
||||
let listeners = listeners.clone();
|
||||
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
|
|
@ -118,6 +125,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::POST, "/agents/v1/chat/completions") => {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
|
||||
agent_chat(req, router_service, fully_qualified_url, agents_list, listeners)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await),
|
||||
(&Method::OPTIONS, "/v1/models") => {
|
||||
let mut response = Response::new(empty());
|
||||
|
|
@ -143,6 +156,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
Ok(response)
|
||||
}
|
||||
_ => {
|
||||
debug!("No route for {} {}", req.method(), req.uri().path());
|
||||
let mut not_found = Response::new(empty());
|
||||
*not_found.status_mut() = StatusCode::NOT_FOUND;
|
||||
Ok(not_found)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,28 @@ pub struct Routing {
|
|||
pub model: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Agent {
|
||||
pub name: String,
|
||||
pub kind: String,
|
||||
pub endpoint: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AgentPipeline {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub filter_chain: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Listener {
|
||||
pub name: String,
|
||||
pub router: Option<String>,
|
||||
pub agents: Option<Vec<AgentPipeline>>,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub version: String,
|
||||
|
|
@ -27,6 +49,8 @@ pub struct Configuration {
|
|||
pub tracing: Option<Tracing>,
|
||||
pub mode: Option<GatewayMode>,
|
||||
pub routing: Option<Routing>,
|
||||
pub agents: Option<Vec<Agent>>,
|
||||
pub listeners: Vec<Listener>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue