Introduce brightstaff a new terminal service for llm routing (#477)

This commit is contained in:
Adil Hafeez 2025-05-19 09:59:22 -07:00 committed by GitHub
parent 1f95fac4af
commit 27c0f2fdce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 2817 additions and 150 deletions

View file

@ -0,0 +1,168 @@
use std::sync::Arc;
use bytes::Bytes;
use common::api::open_ai::ChatCompletionsRequest;
use common::consts::ARCH_PROVIDER_HINT_HEADER;
use common::utils::shorten_string;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{info, warn};
use crate::router::llm_router::RouterService;
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}
pub async fn chat_completions(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
llm_provider_endpoint: String,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let mut request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
let chat_completion_request: ChatCompletionsRequest =
match serde_json::from_slice(&chat_request_bytes) {
Ok(request) => request,
Err(err) => {
let err_msg = format!("Failed to parse request body: {}", err);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
};
info!(
"request body received: {}",
shorten_string(&serde_json::to_string(&chat_completion_request).unwrap())
);
let trace_parent = request_headers
.iter()
.find(|(ty, _)| ty.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
let selected_llm = match router_service
.determine_route(&chat_completion_request.messages, trace_parent.clone())
.await
{
Ok(route) => route,
Err(err) => {
let err_msg = format!("Failed to determine route: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
info!(
"sending request to llm provider: {} with llm model: {:?}",
llm_provider_endpoint, selected_llm
);
if let Some(trace_parent) = trace_parent {
request_headers.insert(
header::HeaderName::from_static("traceparent"),
header::HeaderValue::from_str(&trace_parent).unwrap(),
);
}
if let Some(selected_llm) = selected_llm {
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&selected_llm).unwrap(),
);
}
let llm_response = match reqwest::Client::new()
.post(llm_provider_endpoint)
.headers(request_headers)
.body(chat_request_bytes)
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
// copy over the headers from the original response
let response_headers = llm_response.headers().clone();
let mut response = Response::builder();
let headers = response.headers_mut().unwrap();
for (header_name, header_value) in response_headers.iter() {
headers.insert(header_name, header_value.clone());
}
if chat_completion_request.stream {
// channel to create async stream
let (tx, rx) = mpsc::channel::<Bytes>(16);
// Spawn a task to send data as it becomes available
tokio::spawn(async move {
let mut byte_stream = llm_response.bytes_stream();
while let Some(item) = byte_stream.next().await {
let item = match item {
Ok(item) => item,
Err(err) => {
warn!("Error receiving chunk: {:?}", err);
break;
}
};
if tx.send(item).await.is_err() {
warn!("Receiver dropped");
break;
}
}
});
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
match response.body(stream_body) {
Ok(response) => Ok(response),
Err(err) => {
let err_msg = format!("Failed to create response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Ok(internal_error)
}
}
} else {
let body = match llm_response.text().await {
Ok(body) => body,
Err(err) => {
let err_msg = format!("Failed to read response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
match response.body(full(body)) {
Ok(response) => Ok(response),
Err(err) => {
let err_msg = format!("Failed to create response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Ok(internal_error)
}
}
}
}

View file

@ -0,0 +1 @@
pub mod chat_completions;

View file

@ -0,0 +1,2 @@
pub mod handlers;
pub mod router;

View file

@ -0,0 +1,157 @@
use brightstaff::handlers::chat_completions::chat_completions;
use brightstaff::router::llm_router::RouterService;
use bytes::Bytes;
use common::configuration::Configuration;
use common::utils::shorten_string;
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use opentelemetry::global::BoxedTracer;
use opentelemetry::trace::FutureExt;
use opentelemetry::{
global,
trace::{SpanKind, Tracer},
Context,
};
use opentelemetry_http::HeaderExtractor;
use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider};
use opentelemetry_stdout::SpanExporter;
use std::sync::{Arc, OnceLock};
use std::{env, fs};
use tokio::net::TcpListener;
use tracing::info;
use tracing_subscriber::EnvFilter;
pub mod router;
const BIND_ADDRESS: &str = "0.0.0.0:9091";
fn get_tracer() -> &'static BoxedTracer {
static TRACER: OnceLock<BoxedTracer> = OnceLock::new();
TRACER.get_or_init(|| global::tracer("archgw/router"))
}
// Utility function to extract the context from the incoming request headers
fn extract_context_from_request(req: &Request<Incoming>) -> Context {
global::get_text_map_propagator(|propagator| {
propagator.extract(&HeaderExtractor(req.headers()))
})
}
fn init_tracer() -> SdkTracerProvider {
global::set_text_map_propagator(TraceContextPropagator::new());
// Install stdout exporter pipeline to be able to retrieve the collected spans.
// For the demonstration, use `Sampler::AlwaysOn` sampler to sample all traces.
let provider = SdkTracerProvider::builder()
.with_simple_exporter(SpanExporter::default())
.build();
global::set_tracer_provider(provider.clone());
provider
}
fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new()
.map_err(|never| match never {})
.boxed()
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let _tracer_provider = init_tracer();
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
)
.init();
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
//loading arch_config.yaml file
let arch_config_path =
env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "./arch_config.yaml".to_string());
info!("Loading arch_config.yaml from {}", arch_config_path);
let config_contents =
fs::read_to_string(&arch_config_path).expect("Failed to read arch_config.yaml");
let config: Configuration =
serde_yaml::from_str(&config_contents).expect("Failed to parse arch_config.yaml");
let arch_config = Arc::new(config);
info!(
"arch_config: {:?}",
shorten_string(&serde_json::to_string(arch_config.as_ref()).unwrap())
);
let llm_provider_endpoint = env::var("LLM_PROVIDER_ENDPOINT")
.unwrap_or_else(|_| "http://localhost:12001/v1/chat/completions".to_string());
info!("llm provider endpoint: {}", llm_provider_endpoint);
info!("Listening on http://{}", bind_address);
let listener = TcpListener::bind(bind_address).await?;
// if routing is null then return gpt-4o as model name
let model = arch_config.routing.as_ref().map_or_else(
|| "gpt-4o".to_string(),
|routing| routing.model.clone(),
);
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
arch_config.llm_providers.clone(),
llm_provider_endpoint.clone(),
model,
));
loop {
let (stream, _) = listener.accept().await?;
let peer_addr = stream.peer_addr()?;
let io = TokioIo::new(stream);
let router_service = Arc::clone(&router_service);
let llm_provider_endpoint = llm_provider_endpoint.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
let parent_cx = extract_context_from_request(&req);
info!("parent_cx: {:?}", parent_cx);
let tracer = get_tracer();
let _span = tracer
.span_builder("request")
.with_kind(SpanKind::Server)
.start_with_context(tracer, &parent_cx);
let llm_provider_endpoint = llm_provider_endpoint.clone();
async move {
match (req.method(), req.uri().path()) {
(&Method::POST, "/v1/chat/completions") => {
chat_completions(req, router_service, llm_provider_endpoint)
.with_context(parent_cx)
.await
}
_ => {
let mut not_found = Response::new(empty());
*not_found.status_mut() = StatusCode::NOT_FOUND;
Ok(not_found)
}
}
}
});
tokio::task::spawn(async move {
info!("Accepted connection from {:?}", peer_addr);
if let Err(err) = http1::Builder::new()
// .serve_connection(io, service_fn(chat_completion))
.serve_connection(io, service)
.await
{
info!("Error serving connection: {:?}", err);
}
});
}
}

View file

@ -0,0 +1,151 @@
use std::sync::Arc;
use common::{
api::open_ai::{ChatCompletionsResponse, Message},
configuration::LlmProvider,
consts::ARCH_PROVIDER_HINT_HEADER,
utils::shorten_string,
};
use hyper::header;
use thiserror::Error;
use tracing::{info, warn};
use super::router_model::RouterModel;
pub struct RouterService {
router_url: String,
client: reqwest::Client,
router_model: Arc<dyn RouterModel>,
routing_model_name: String,
llm_usage_defined: bool,
}
#[derive(Debug, Error)]
pub enum RoutingError {
#[error("Failed to send request: {0}")]
RequestError(#[from] reqwest::Error),
#[error("Failed to parse JSON: {0}, JSON: {1}")]
JsonError(serde_json::Error, String),
#[error("Router model error: {0}")]
RouterModelError(#[from] super::router_model::RoutingModelError),
}
pub type Result<T> = std::result::Result<T, RoutingError>;
impl RouterService {
pub fn new(
providers: Vec<LlmProvider>,
router_url: String,
routing_model_name: String,
) -> Self {
let providers_with_usage = providers
.iter()
.filter(|provider| provider.usage.is_some())
.cloned()
.collect::<Vec<LlmProvider>>();
// convert the llm_providers to yaml string but only include name and usage
let llm_providers_with_usage_yaml = providers_with_usage
.iter()
.map(|provider| {
format!(
"- name: {}\n description: {}",
provider.name,
provider.usage.as_ref().unwrap_or(&"".to_string())
)
})
.collect::<Vec<String>>()
.join("\n");
info!(
"llm_providers from config with usage: {}...",
shorten_string(&llm_providers_with_usage_yaml.replace("\n", "\\n"))
);
let router_model = Arc::new(super::router_model_v1::RouterModelV1::new(
llm_providers_with_usage_yaml.clone(),
routing_model_name.clone(),
));
RouterService {
router_url,
client: reqwest::Client::new(),
router_model,
routing_model_name,
llm_usage_defined: !providers_with_usage.is_empty(),
}
}
pub async fn determine_route(
&self,
messages: &[Message],
trace_parent: Option<String>,
) -> Result<Option<String>> {
if !self.llm_usage_defined {
return Ok(None);
}
let router_request = self.router_model.generate_request(messages);
info!(
"router_request: {}",
shorten_string(&serde_json::to_string(&router_request).unwrap()),
);
let mut llm_route_request_headers = header::HeaderMap::new();
llm_route_request_headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
llm_route_request_headers.insert(
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
header::HeaderValue::from_str(&self.routing_model_name).unwrap(),
);
if let Some(trace_parent) = trace_parent {
llm_route_request_headers.insert(
header::HeaderName::from_static("traceparent"),
header::HeaderValue::from_str(&trace_parent).unwrap(),
);
}
let res = self
.client
.post(&self.router_url)
.headers(llm_route_request_headers)
.body(serde_json::to_string(&router_request).unwrap())
.send()
.await?;
let body = res.text().await?;
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
Ok(response) => response,
Err(err) => {
warn!(
"Failed to parse JSON: {}. Body: {}",
err,
&serde_json::to_string(&body).unwrap()
);
return Err(RoutingError::JsonError(
err,
format!("Failed to parse JSON: {}", body),
));
}
};
let selected_llm = self.router_model.parse_response(
chat_completion_response.choices[0]
.message
.content
.as_ref()
.unwrap(),
)?;
Ok(selected_llm)
}
}

View file

@ -0,0 +1,3 @@
pub mod llm_router;
pub mod router_model;
pub mod router_model_v1;

View file

@ -0,0 +1,15 @@
use common::api::open_ai::{ChatCompletionsRequest, Message};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum RoutingModelError {
#[error("Failed to parse JSON: {0}")]
JsonError(#[from] serde_json::Error),
}
pub type Result<T> = std::result::Result<T, RoutingModelError>;
pub trait RouterModel: Send + Sync {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest;
fn parse_response(&self, content: &str) -> Result<Option<String>>;
}

View file

@ -0,0 +1,251 @@
use common::{
api::open_ai::{ChatCompletionsRequest, Message},
consts::{SYSTEM_ROLE, USER_ROLE},
};
use serde::{Deserialize, Serialize};
use tracing::info;
use super::router_model::{RouterModel, RoutingModelError};
pub const ARCH_ROUTER_V1_SYSTEM_PROMPT: &str = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
{routes}
</routes>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant, response with empty route {"route": ""}.
2. If the user request is full fill and user thank or ending the conversation , response with empty route {"route": ""}.
3. Understand user latest intent and find the best match route in <routes></routes> xml tags.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
<conversation>
{conversation}
</conversation>
"#;
pub type Result<T> = std::result::Result<T, RoutingModelError>;
pub struct RouterModelV1 {
llm_providers_with_usage_yaml: String,
routing_model: String,
}
impl RouterModelV1 {
pub fn new(llm_providers_with_usage_yaml: String, routing_model: String) -> Self {
RouterModelV1 {
llm_providers_with_usage_yaml,
routing_model,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LlmRouterResponse {
pub route: Option<String>,
}
impl RouterModel for RouterModelV1 {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
let messages_str = messages
.iter()
.filter(|m| m.role != SYSTEM_ROLE)
.map(|m| {
let content_json_str = serde_json::to_string(&m.content).unwrap_or_default();
format!("{}: {}", m.role, content_json_str)
})
.collect::<Vec<String>>()
.join("\n");
let message = ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", &self.llm_providers_with_usage_yaml)
.replace("{conversation}", messages_str.as_str());
ChatCompletionsRequest {
model: self.routing_model.clone(),
messages: vec![Message {
content: Some(message),
role: USER_ROLE.to_string(),
model: None,
tool_calls: None,
tool_call_id: None,
}],
tools: None,
stream: false,
stream_options: None,
metadata: None,
}
}
fn parse_response(&self, content: &str) -> Result<Option<String>> {
if content.is_empty() {
return Ok(None);
}
let router_resp_fixed = fix_json_response(content);
info!(
"router response (fixed): {}",
router_resp_fixed.replace("\n", "\\n")
);
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
let selected_llm = router_response.route.unwrap_or_default().to_string();
if selected_llm.is_empty() {
return Ok(None);
}
Ok(Some(selected_llm))
}
}
fn fix_json_response(body: &str) -> String {
let mut updated_body = body.to_string();
updated_body = updated_body.replace("'", "\"");
if updated_body.contains("\\n") {
updated_body = updated_body.replace("\\n", "");
}
if updated_body.starts_with("```json") {
updated_body = updated_body
.strip_prefix("```json")
.unwrap_or(&updated_body)
.to_string();
}
if updated_body.ends_with("```") {
updated_body = updated_body
.strip_suffix("```")
.unwrap_or(&updated_body)
.to_string();
}
updated_body
}
impl std::fmt::Debug for dyn RouterModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "RouterModel")
}
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn test_system_prompt_format() {
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
route1: description1
route2: description2
</routes>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant, response with empty route {"route": ""}.
2. If the user request is full fill and user thank or ending the conversation , response with empty route {"route": ""}.
3. Understand user latest intent and find the best match route in <routes></routes> xml tags.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
<conversation>
user: "Hello, I want to book a flight."
assistant: "Sure, where would you like to go?"
user: "seattle"
</conversation>
"#;
let routes_yaml = "route1: description1\nroute2: description2";
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone());
let messages = vec![
Message {
role: "system".to_string(),
content: Some("You are a helpful assistant.".to_string()),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some("Hello, I want to book a flight.".to_string()),
..Default::default()
},
Message {
role: "assistant".to_string(),
content: Some("Sure, where would you like to go?".to_string()),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some("seattle".to_string()),
..Default::default()
},
];
let req = router.generate_request(&messages);
let prompt = req.messages[0].content.as_ref().unwrap();
println!("Prompt: {}", prompt);
assert_eq!(expected_prompt, prompt);
}
}
#[test]
fn test_parse_response() {
let router = RouterModelV1::new(
"route1: description1\nroute2: description2".to_string(),
"test-model".to_string(),
);
// Case 1: Valid JSON with non-empty route
let input = r#"{"route": "route1"}"#;
let result = router.parse_response(input).unwrap();
assert_eq!(result, Some("route1".to_string()));
// Case 2: Valid JSON with empty route
let input = r#"{"route": ""}"#;
let result = router.parse_response(input).unwrap();
assert_eq!(result, None);
// Case 3: Valid JSON with null route
let input = r#"{"route": null}"#;
let result = router.parse_response(input).unwrap();
assert_eq!(result, None);
// Case 4: JSON missing route field
let input = r#"{}"#;
let result = router.parse_response(input).unwrap();
assert_eq!(result, None);
// Case 4.1: empty string
let input = r#""#;
let result = router.parse_response(input).unwrap();
assert_eq!(result, None);
// Case 5: Malformed JSON
let input = r#"{"route": "route1""#; // missing closing }
let result = router.parse_response(input);
assert!(result.is_err());
// Case 6: Single quotes and \n in JSON
let input = "{'route': 'route2'}\\n";
let result = router.parse_response(input).unwrap();
assert_eq!(result, Some("route2".to_string()));
// Case 7: Code block marker
let input = "```json\n{\"route\": \"route1\"}\n```";
let result = router.parse_response(input).unwrap();
assert_eq!(result, Some("route1".to_string()));
}