mirror of
https://github.com/katanemo/plano.git
synced 2026-06-20 15:28:07 +02:00
Introduce brightstaff a new terminal service for llm routing (#477)
This commit is contained in:
parent
1f95fac4af
commit
27c0f2fdce
36 changed files with 2817 additions and 150 deletions
1736
crates/Cargo.lock
generated
1736
crates/Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,3 +1,3 @@
|
|||
[workspace]
|
||||
resolver = "2"
|
||||
members = ["llm_gateway", "prompt_gateway", "common"]
|
||||
members = ["llm_gateway", "prompt_gateway", "common", "brightstaff"]
|
||||
|
|
|
|||
32
crates/brightstaff/Cargo.toml
Normal file
32
crates/brightstaff/Cargo.toml
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
[package]
|
||||
name = "brightstaff"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
bytes = "1.10.1"
|
||||
common = { version = "0.1.0", path = "../common" }
|
||||
eventsource-client = "0.15.0"
|
||||
eventsource-stream = "0.2.3"
|
||||
futures = "0.3.31"
|
||||
futures-util = "0.3.31"
|
||||
http-body = "1.0.1"
|
||||
http-body-util = "0.1.3"
|
||||
hyper = { version = "1.6.0", features = ["full"] }
|
||||
hyper-util = "0.1.11"
|
||||
opentelemetry = "0.29.1"
|
||||
opentelemetry-http = "0.29.0"
|
||||
opentelemetry-otlp = {version="0.29.0", features=["trace", "tonic", "grpc-tonic"]}
|
||||
opentelemetry-stdout = "0.29.0"
|
||||
opentelemetry_sdk = "0.29.0"
|
||||
pretty_assertions = "1.4.1"
|
||||
reqwest = { version = "0.12.15", features = ["stream"] }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
serde_json = "1.0.140"
|
||||
serde_yaml = "0.9.34"
|
||||
thiserror = "2.0.12"
|
||||
tokio = { version = "1.44.2", features = ["full"] }
|
||||
tokio-stream = "0.1.17"
|
||||
tracing = "0.1.41"
|
||||
tracing-opentelemetry = "0.30.0"
|
||||
tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt"] }
|
||||
168
crates/brightstaff/src/handlers/chat_completions.rs
Normal file
168
crates/brightstaff/src/handlers/chat_completions.rs
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::api::open_ai::ChatCompletionsRequest;
|
||||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||
use common::utils::shorten_string;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
|
||||
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub async fn chat_completions(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
llm_provider_endpoint: String,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let mut request_headers = request.headers().clone();
|
||||
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
let chat_completion_request: ChatCompletionsRequest =
|
||||
match serde_json::from_slice(&chat_request_bytes) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to parse request body: {}", err);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
};
|
||||
|
||||
info!(
|
||||
"request body received: {}",
|
||||
shorten_string(&serde_json::to_string(&chat_completion_request).unwrap())
|
||||
);
|
||||
|
||||
let trace_parent = request_headers
|
||||
.iter()
|
||||
.find(|(ty, _)| ty.as_str() == "traceparent")
|
||||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
let selected_llm = match router_service
|
||||
.determine_route(&chat_completion_request.messages, trace_parent.clone())
|
||||
.await
|
||||
{
|
||||
Ok(route) => route,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to determine route: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
info!(
|
||||
"sending request to llm provider: {} with llm model: {:?}",
|
||||
llm_provider_endpoint, selected_llm
|
||||
);
|
||||
|
||||
if let Some(trace_parent) = trace_parent {
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static("traceparent"),
|
||||
header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(selected_llm) = selected_llm {
|
||||
request_headers.insert(
|
||||
ARCH_PROVIDER_HINT_HEADER,
|
||||
header::HeaderValue::from_str(&selected_llm).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post(llm_provider_endpoint)
|
||||
.headers(request_headers)
|
||||
.body(chat_request_bytes)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
// copy over the headers from the original response
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let mut response = Response::builder();
|
||||
let headers = response.headers_mut().unwrap();
|
||||
for (header_name, header_value) in response_headers.iter() {
|
||||
headers.insert(header_name, header_value.clone());
|
||||
}
|
||||
|
||||
if chat_completion_request.stream {
|
||||
// channel to create async stream
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(16);
|
||||
|
||||
// Spawn a task to send data as it becomes available
|
||||
tokio::spawn(async move {
|
||||
let mut byte_stream = llm_response.bytes_stream();
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let item = match item {
|
||||
Ok(item) => item,
|
||||
Err(err) => {
|
||||
warn!("Error receiving chunk: {:?}", err);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if tx.send(item).await.is_err() {
|
||||
warn!("Receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
|
||||
let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
||||
match response.body(stream_body) {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to create response: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
Ok(internal_error)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let body = match llm_response.text().await {
|
||||
Ok(body) => body,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to read response: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
match response.body(full(body)) {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to create response: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
Ok(internal_error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
1
crates/brightstaff/src/handlers/mod.rs
Normal file
1
crates/brightstaff/src/handlers/mod.rs
Normal file
|
|
@ -0,0 +1 @@
|
|||
pub mod chat_completions;
|
||||
2
crates/brightstaff/src/lib.rs
Normal file
2
crates/brightstaff/src/lib.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
pub mod handlers;
|
||||
pub mod router;
|
||||
157
crates/brightstaff/src/main.rs
Normal file
157
crates/brightstaff/src/main.rs
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
use brightstaff::handlers::chat_completions::chat_completions;
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::Configuration;
|
||||
use common::utils::shorten_string;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::server::conn::http1;
|
||||
use hyper::service::service_fn;
|
||||
use hyper::{Method, Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use opentelemetry::global::BoxedTracer;
|
||||
use opentelemetry::trace::FutureExt;
|
||||
use opentelemetry::{
|
||||
global,
|
||||
trace::{SpanKind, Tracer},
|
||||
Context,
|
||||
};
|
||||
use opentelemetry_http::HeaderExtractor;
|
||||
use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider};
|
||||
use opentelemetry_stdout::SpanExporter;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::{env, fs};
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
pub mod router;
|
||||
|
||||
const BIND_ADDRESS: &str = "0.0.0.0:9091";
|
||||
|
||||
fn get_tracer() -> &'static BoxedTracer {
|
||||
static TRACER: OnceLock<BoxedTracer> = OnceLock::new();
|
||||
TRACER.get_or_init(|| global::tracer("archgw/router"))
|
||||
}
|
||||
|
||||
// Utility function to extract the context from the incoming request headers
|
||||
fn extract_context_from_request(req: &Request<Incoming>) -> Context {
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
propagator.extract(&HeaderExtractor(req.headers()))
|
||||
})
|
||||
}
|
||||
|
||||
fn init_tracer() -> SdkTracerProvider {
|
||||
global::set_text_map_propagator(TraceContextPropagator::new());
|
||||
// Install stdout exporter pipeline to be able to retrieve the collected spans.
|
||||
// For the demonstration, use `Sampler::AlwaysOn` sampler to sample all traces.
|
||||
let provider = SdkTracerProvider::builder()
|
||||
.with_simple_exporter(SpanExporter::default())
|
||||
.build();
|
||||
|
||||
global::set_tracer_provider(provider.clone());
|
||||
provider
|
||||
}
|
||||
|
||||
fn empty() -> BoxBody<Bytes, hyper::Error> {
|
||||
Empty::<Bytes>::new()
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let _tracer_provider = init_tracer();
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
|
||||
)
|
||||
.init();
|
||||
|
||||
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
|
||||
|
||||
//loading arch_config.yaml file
|
||||
let arch_config_path =
|
||||
env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "./arch_config.yaml".to_string());
|
||||
info!("Loading arch_config.yaml from {}", arch_config_path);
|
||||
|
||||
let config_contents =
|
||||
fs::read_to_string(&arch_config_path).expect("Failed to read arch_config.yaml");
|
||||
|
||||
let config: Configuration =
|
||||
serde_yaml::from_str(&config_contents).expect("Failed to parse arch_config.yaml");
|
||||
|
||||
let arch_config = Arc::new(config);
|
||||
|
||||
info!(
|
||||
"arch_config: {:?}",
|
||||
shorten_string(&serde_json::to_string(arch_config.as_ref()).unwrap())
|
||||
);
|
||||
|
||||
let llm_provider_endpoint = env::var("LLM_PROVIDER_ENDPOINT")
|
||||
.unwrap_or_else(|_| "http://localhost:12001/v1/chat/completions".to_string());
|
||||
|
||||
info!("llm provider endpoint: {}", llm_provider_endpoint);
|
||||
info!("Listening on http://{}", bind_address);
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
|
||||
|
||||
// if routing is null then return gpt-4o as model name
|
||||
let model = arch_config.routing.as_ref().map_or_else(
|
||||
|| "gpt-4o".to_string(),
|
||||
|routing| routing.model.clone(),
|
||||
);
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||
arch_config.llm_providers.clone(),
|
||||
llm_provider_endpoint.clone(),
|
||||
model,
|
||||
));
|
||||
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
let peer_addr = stream.peer_addr()?;
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
||||
|
||||
let service = service_fn(move |req| {
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
info!("parent_cx: {:?}", parent_cx);
|
||||
let tracer = get_tracer();
|
||||
let _span = tracer
|
||||
.span_builder("request")
|
||||
.with_kind(SpanKind::Server)
|
||||
.start_with_context(tracer, &parent_cx);
|
||||
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
||||
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::POST, "/v1/chat/completions") => {
|
||||
chat_completions(req, router_service, llm_provider_endpoint)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
_ => {
|
||||
let mut not_found = Response::new(empty());
|
||||
*not_found.status_mut() = StatusCode::NOT_FOUND;
|
||||
Ok(not_found)
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
info!("Accepted connection from {:?}", peer_addr);
|
||||
if let Err(err) = http1::Builder::new()
|
||||
// .serve_connection(io, service_fn(chat_completion))
|
||||
.serve_connection(io, service)
|
||||
.await
|
||||
{
|
||||
info!("Error serving connection: {:?}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
151
crates/brightstaff/src/router/llm_router.rs
Normal file
151
crates/brightstaff/src/router/llm_router.rs
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use common::{
|
||||
api::open_ai::{ChatCompletionsResponse, Message},
|
||||
configuration::LlmProvider,
|
||||
consts::ARCH_PROVIDER_HINT_HEADER,
|
||||
utils::shorten_string,
|
||||
};
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::router_model::RouterModel;
|
||||
|
||||
pub struct RouterService {
|
||||
router_url: String,
|
||||
client: reqwest::Client,
|
||||
router_model: Arc<dyn RouterModel>,
|
||||
routing_model_name: String,
|
||||
llm_usage_defined: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RoutingError {
|
||||
#[error("Failed to send request: {0}")]
|
||||
RequestError(#[from] reqwest::Error),
|
||||
|
||||
#[error("Failed to parse JSON: {0}, JSON: {1}")]
|
||||
JsonError(serde_json::Error, String),
|
||||
|
||||
#[error("Router model error: {0}")]
|
||||
RouterModelError(#[from] super::router_model::RoutingModelError),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingError>;
|
||||
|
||||
impl RouterService {
|
||||
pub fn new(
|
||||
providers: Vec<LlmProvider>,
|
||||
router_url: String,
|
||||
routing_model_name: String,
|
||||
) -> Self {
|
||||
let providers_with_usage = providers
|
||||
.iter()
|
||||
.filter(|provider| provider.usage.is_some())
|
||||
.cloned()
|
||||
.collect::<Vec<LlmProvider>>();
|
||||
|
||||
// convert the llm_providers to yaml string but only include name and usage
|
||||
let llm_providers_with_usage_yaml = providers_with_usage
|
||||
.iter()
|
||||
.map(|provider| {
|
||||
format!(
|
||||
"- name: {}\n description: {}",
|
||||
provider.name,
|
||||
provider.usage.as_ref().unwrap_or(&"".to_string())
|
||||
)
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
info!(
|
||||
"llm_providers from config with usage: {}...",
|
||||
shorten_string(&llm_providers_with_usage_yaml.replace("\n", "\\n"))
|
||||
);
|
||||
|
||||
let router_model = Arc::new(super::router_model_v1::RouterModelV1::new(
|
||||
llm_providers_with_usage_yaml.clone(),
|
||||
routing_model_name.clone(),
|
||||
));
|
||||
|
||||
RouterService {
|
||||
router_url,
|
||||
client: reqwest::Client::new(),
|
||||
router_model,
|
||||
routing_model_name,
|
||||
llm_usage_defined: !providers_with_usage.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn determine_route(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
trace_parent: Option<String>,
|
||||
) -> Result<Option<String>> {
|
||||
|
||||
if !self.llm_usage_defined {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let router_request = self.router_model.generate_request(messages);
|
||||
|
||||
info!(
|
||||
"router_request: {}",
|
||||
shorten_string(&serde_json::to_string(&router_request).unwrap()),
|
||||
);
|
||||
|
||||
let mut llm_route_request_headers = header::HeaderMap::new();
|
||||
llm_route_request_headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
llm_route_request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||
header::HeaderValue::from_str(&self.routing_model_name).unwrap(),
|
||||
);
|
||||
|
||||
if let Some(trace_parent) = trace_parent {
|
||||
llm_route_request_headers.insert(
|
||||
header::HeaderName::from_static("traceparent"),
|
||||
header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
let res = self
|
||||
.client
|
||||
.post(&self.router_url)
|
||||
.headers(llm_route_request_headers)
|
||||
.body(serde_json::to_string(&router_request).unwrap())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let body = res.text().await?;
|
||||
|
||||
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Failed to parse JSON: {}. Body: {}",
|
||||
err,
|
||||
&serde_json::to_string(&body).unwrap()
|
||||
);
|
||||
return Err(RoutingError::JsonError(
|
||||
err,
|
||||
format!("Failed to parse JSON: {}", body),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let selected_llm = self.router_model.parse_response(
|
||||
chat_completion_response.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
)?;
|
||||
|
||||
Ok(selected_llm)
|
||||
}
|
||||
}
|
||||
3
crates/brightstaff/src/router/mod.rs
Normal file
3
crates/brightstaff/src/router/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
pub mod llm_router;
|
||||
pub mod router_model;
|
||||
pub mod router_model_v1;
|
||||
15
crates/brightstaff/src/router/router_model.rs
Normal file
15
crates/brightstaff/src/router/router_model.rs
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
use common::api::open_ai::{ChatCompletionsRequest, Message};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RoutingModelError {
|
||||
#[error("Failed to parse JSON: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingModelError>;
|
||||
|
||||
pub trait RouterModel: Send + Sync {
|
||||
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest;
|
||||
fn parse_response(&self, content: &str) -> Result<Option<String>>;
|
||||
}
|
||||
251
crates/brightstaff/src/router/router_model_v1.rs
Normal file
251
crates/brightstaff/src/router/router_model_v1.rs
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
use common::{
|
||||
api::open_ai::{ChatCompletionsRequest, Message},
|
||||
consts::{SYSTEM_ROLE, USER_ROLE},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
|
||||
use super::router_model::{RouterModel, RoutingModelError};
|
||||
|
||||
pub const ARCH_ROUTER_V1_SYSTEM_PROMPT: &str = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
{routes}
|
||||
</routes>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant, response with empty route {"route": ""}.
|
||||
2. If the user request is full fill and user thank or ending the conversation , response with empty route {"route": ""}.
|
||||
3. Understand user latest intent and find the best match route in <routes></routes> xml tags.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
|
||||
|
||||
<conversation>
|
||||
{conversation}
|
||||
</conversation>
|
||||
"#;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingModelError>;
|
||||
|
||||
pub struct RouterModelV1 {
|
||||
llm_providers_with_usage_yaml: String,
|
||||
routing_model: String,
|
||||
}
|
||||
|
||||
impl RouterModelV1 {
|
||||
pub fn new(llm_providers_with_usage_yaml: String, routing_model: String) -> Self {
|
||||
RouterModelV1 {
|
||||
llm_providers_with_usage_yaml,
|
||||
routing_model,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct LlmRouterResponse {
|
||||
pub route: Option<String>,
|
||||
}
|
||||
|
||||
impl RouterModel for RouterModelV1 {
|
||||
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
|
||||
let messages_str = messages
|
||||
.iter()
|
||||
.filter(|m| m.role != SYSTEM_ROLE)
|
||||
.map(|m| {
|
||||
let content_json_str = serde_json::to_string(&m.content).unwrap_or_default();
|
||||
format!("{}: {}", m.role, content_json_str)
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
let message = ARCH_ROUTER_V1_SYSTEM_PROMPT
|
||||
.replace("{routes}", &self.llm_providers_with_usage_yaml)
|
||||
.replace("{conversation}", messages_str.as_str());
|
||||
|
||||
ChatCompletionsRequest {
|
||||
model: self.routing_model.clone(),
|
||||
messages: vec![Message {
|
||||
content: Some(message),
|
||||
role: USER_ROLE.to_string(),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
tools: None,
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, content: &str) -> Result<Option<String>> {
|
||||
if content.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
let router_resp_fixed = fix_json_response(content);
|
||||
info!(
|
||||
"router response (fixed): {}",
|
||||
router_resp_fixed.replace("\n", "\\n")
|
||||
);
|
||||
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
|
||||
|
||||
let selected_llm = router_response.route.unwrap_or_default().to_string();
|
||||
|
||||
if selected_llm.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(selected_llm))
|
||||
}
|
||||
}
|
||||
|
||||
fn fix_json_response(body: &str) -> String {
|
||||
let mut updated_body = body.to_string();
|
||||
|
||||
updated_body = updated_body.replace("'", "\"");
|
||||
|
||||
if updated_body.contains("\\n") {
|
||||
updated_body = updated_body.replace("\\n", "");
|
||||
}
|
||||
|
||||
if updated_body.starts_with("```json") {
|
||||
updated_body = updated_body
|
||||
.strip_prefix("```json")
|
||||
.unwrap_or(&updated_body)
|
||||
.to_string();
|
||||
}
|
||||
|
||||
if updated_body.ends_with("```") {
|
||||
updated_body = updated_body
|
||||
.strip_suffix("```")
|
||||
.unwrap_or(&updated_body)
|
||||
.to_string();
|
||||
}
|
||||
|
||||
updated_body
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn RouterModel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "RouterModel")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn test_system_prompt_format() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
route1: description1
|
||||
route2: description2
|
||||
</routes>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant, response with empty route {"route": ""}.
|
||||
2. If the user request is full fill and user thank or ending the conversation , response with empty route {"route": ""}.
|
||||
3. Understand user latest intent and find the best match route in <routes></routes> xml tags.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
|
||||
|
||||
<conversation>
|
||||
user: "Hello, I want to book a flight."
|
||||
assistant: "Sure, where would you like to go?"
|
||||
user: "seattle"
|
||||
</conversation>
|
||||
"#;
|
||||
|
||||
let routes_yaml = "route1: description1\nroute2: description2";
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone());
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "system".to_string(),
|
||||
content: Some("You are a helpful assistant.".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("Hello, I want to book a flight.".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Sure, where would you like to go?".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("seattle".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
];
|
||||
|
||||
let req = router.generate_request(&messages);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
|
||||
println!("Prompt: {}", prompt);
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_response() {
|
||||
let router = RouterModelV1::new(
|
||||
"route1: description1\nroute2: description2".to_string(),
|
||||
"test-model".to_string(),
|
||||
);
|
||||
|
||||
// Case 1: Valid JSON with non-empty route
|
||||
let input = r#"{"route": "route1"}"#;
|
||||
let result = router.parse_response(input).unwrap();
|
||||
assert_eq!(result, Some("route1".to_string()));
|
||||
|
||||
// Case 2: Valid JSON with empty route
|
||||
let input = r#"{"route": ""}"#;
|
||||
let result = router.parse_response(input).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 3: Valid JSON with null route
|
||||
let input = r#"{"route": null}"#;
|
||||
let result = router.parse_response(input).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 4: JSON missing route field
|
||||
let input = r#"{}"#;
|
||||
let result = router.parse_response(input).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 4.1: empty string
|
||||
let input = r#""#;
|
||||
let result = router.parse_response(input).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 5: Malformed JSON
|
||||
let input = r#"{"route": "route1""#; // missing closing }
|
||||
let result = router.parse_response(input);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Case 6: Single quotes and \n in JSON
|
||||
let input = "{'route': 'route2'}\\n";
|
||||
let result = router.parse_response(input).unwrap();
|
||||
assert_eq!(result, Some("route2".to_string()));
|
||||
|
||||
// Case 7: Code block marker
|
||||
let input = "```json\n{\"route\": \"route1\"}\n```";
|
||||
let result = router.parse_response(input).unwrap();
|
||||
assert_eq!(result, Some("route1".to_string()));
|
||||
}
|
||||
|
|
@ -171,6 +171,18 @@ pub struct Message {
|
|||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for Message {
|
||||
fn default() -> Self {
|
||||
Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: None,
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub finish_reason: Option<String>,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,11 @@ use crate::api::open_ai::{
|
|||
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Routing {
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub version: String,
|
||||
|
|
@ -19,6 +24,7 @@ pub struct Configuration {
|
|||
pub ratelimits: Option<Vec<Ratelimit>>,
|
||||
pub tracing: Option<Tracing>,
|
||||
pub mode: Option<GatewayMode>,
|
||||
pub routing: Option<Routing>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
|
|
@ -166,6 +172,7 @@ pub struct LlmProvider {
|
|||
pub endpoint: Option<String>,
|
||||
pub port: Option<u16>,
|
||||
pub rate_limits: Option<LlmRatelimit>,
|
||||
pub usage: Option<String>,
|
||||
}
|
||||
|
||||
impl Display for LlmProvider {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,8 @@ pub const MODEL_SERVER_NAME: &str = "model_server";
|
|||
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
||||
pub const MESSAGES_KEY: &str = "messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const CHAT_COMPLETIONS_PATH: [&str; 2] = ["/v1/chat/completions", "/openai/v1/chat/completions"];
|
||||
pub const CHAT_COMPLETIONS_PATH: [&str; 2] =
|
||||
["/v1/chat/completions", "/openai/v1/chat/completions"];
|
||||
pub const HEALTHZ_PATH: &str = "/healthz";
|
||||
pub const X_ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message";
|
||||
|
|
@ -27,3 +28,4 @@ pub const HALLUCINATION_TEMPLATE: &str =
|
|||
"It seems I'm missing some information. Could you provide the following details ";
|
||||
pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http";
|
||||
pub const OTEL_POST_PATH: &str = "/v1/traces";
|
||||
pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route";
|
||||
|
|
|
|||
|
|
@ -11,3 +11,4 @@ pub mod routing;
|
|||
pub mod stats;
|
||||
pub mod tokenizer;
|
||||
pub mod tracing;
|
||||
pub mod utils;
|
||||
|
|
|
|||
7
crates/common/src/utils.rs
Normal file
7
crates/common/src/utils.rs
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
pub fn shorten_string(s: &str) -> String {
|
||||
if s.len() > 80 {
|
||||
format!("{}...", &s[..80])
|
||||
} else {
|
||||
s.to_string()
|
||||
}
|
||||
}
|
||||
|
|
@ -228,6 +228,7 @@ impl HttpContext for StreamContext {
|
|||
stream: None,
|
||||
port: None,
|
||||
rate_limits: None,
|
||||
usage: None,
|
||||
}));
|
||||
} else {
|
||||
self.select_llm_provider();
|
||||
|
|
@ -316,10 +317,6 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
// remove metadata from the request body
|
||||
//TODO: move this to prompt gateway
|
||||
// deserialized_body.metadata = None;
|
||||
// delete model key from message array
|
||||
for message in deserialized_body.messages.iter_mut() {
|
||||
message.model = None;
|
||||
}
|
||||
|
|
@ -342,24 +339,22 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
|
||||
let model_requested = deserialized_body.model.clone();
|
||||
if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
|
||||
deserialized_body.model = match model_name {
|
||||
Some(model_name) => model_name.clone(),
|
||||
None => {
|
||||
if use_agent_orchestrator {
|
||||
"agent_orchestrator".to_string()
|
||||
} else {
|
||||
self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
deserialized_body.model = match model_name {
|
||||
Some(model_name) => model_name.clone(),
|
||||
None => {
|
||||
if use_agent_orchestrator {
|
||||
"agent_orchestrator".to_string()
|
||||
} else {
|
||||
self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
info!(
|
||||
"on_http_request_body: provider: {}, model requested: {}, model selected: {}",
|
||||
|
|
|
|||
|
|
@ -489,7 +489,6 @@ fn llm_gateway_override_model_name() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
|
|
|
|||
|
|
@ -777,18 +777,19 @@ impl StreamContext {
|
|||
|
||||
fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool {
|
||||
let content = model_server_response
|
||||
.choices.first()
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref());
|
||||
|
||||
let content_has_value = content.is_some() && !content.unwrap().is_empty();
|
||||
|
||||
let tool_calls = model_server_response
|
||||
.choices.first()
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.tool_calls.as_ref());
|
||||
|
||||
// intent was matched if content has some value or tool_calls is empty
|
||||
|
||||
|
||||
content_has_value || (tool_calls.is_some() && !tool_calls.unwrap().is_empty())
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue