more updates

This commit is contained in:
Adil Hafeez 2025-05-12 12:55:59 -07:00
parent 1d19f0c2f7
commit f13fc76a4a
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
19 changed files with 431 additions and 454 deletions

View file

@ -227,6 +227,12 @@ properties:
enum:
- llm
- prompt
routing:
type: object
properties:
model:
type: string
additionalProperties: false
prompt_guards:
type: object
properties:

16
crates/Cargo.lock generated
View file

@ -2200,11 +2200,13 @@ dependencies = [
"system-configuration",
"tokio",
"tokio-native-tls",
"tokio-util",
"tower",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"windows-registry",
]
@ -3281,6 +3283,19 @@ dependencies = [
"wasmparser 0.219.1",
]
[[package]]
name = "wasm-streams"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e072d4e72f700fb3443d8fe94a39315df013eef1104903cdb0a2abd322bbecd"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "wasmparser"
version = "0.212.0"
@ -3623,6 +3638,7 @@ dependencies = [
"eventsource-client",
"eventsource-stream",
"futures",
"futures-util",
"http-body-util",
"hyper 1.6.0",
"hyper-util",

View file

@ -6,6 +6,11 @@ use crate::api::open_ai::{
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Routing {
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
@ -19,6 +24,7 @@ pub struct Configuration {
pub ratelimits: Option<Vec<Ratelimit>>,
pub tracing: Option<Tracing>,
pub mode: Option<GatewayMode>,
pub routing: Option<Routing>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]

View file

@ -11,3 +11,4 @@ pub mod routing;
pub mod stats;
pub mod tokenizer;
pub mod tracing;
pub mod utils;

View file

@ -0,0 +1,7 @@
pub fn shorten_string(s: &str) -> String {
if s.len() > 80 {
format!("{}...", &s[..80])
} else {
s.to_string()
}
}

View file

@ -9,15 +9,16 @@ common = { version = "0.1.0", path = "../common" }
eventsource-client = "0.15.0"
eventsource-stream = "0.2.3"
futures = "0.3.31"
futures-util = "0.3.31"
http-body-util = "0.1.3"
hyper = { version="1.6.0", features = ["full"] }
hyper = { version = "1.6.0", features = ["full"] }
hyper-util = "0.1.11"
opentelemetry = "0.29.1"
opentelemetry-http = "0.29.0"
opentelemetry-otlp = "0.29.0"
opentelemetry-stdout = "0.29.0"
opentelemetry_sdk = "0.29.0"
reqwest = "0.12.15"
reqwest = { version = "0.12.15", features = ["stream"] }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
serde_yaml = "0.9.34"
@ -26,4 +27,4 @@ tokio = { version = "1.44.2", features = ["full"] }
tokio-stream = "0.1.17"
tracing = "0.1.41"
tracing-opentelemetry = "0.30.0"
tracing-subscriber = { version="0.3.19", features = ["env-filter", "fmt"] }
tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt"] }

View file

@ -1,32 +0,0 @@
pub const SYSTEM_PROMPT_Z: &str = r#"
You are an advanced Routing Assistant designed to select the optimal route based on user requests.
Your task is to analyze conversations and match them to the most appropriate predefined route.
Review the available routes config:
# ROUTES CONFIG START
{routes}
# ROUTES CONFIG END
Examine the following conversation between a user and an assistant:
# CONVERSATION START
{conversation}
# CONVERSATION END
Your goal is to identify the most appropriate route that matches the user's LATEST intent. Follow these steps:
1. Carefully read and analyze the provided conversation, focusing on the user's latest request and the conversation scenario.
2. Check if the user's request and scenario matches any of the routes in the routing configuration (focus on the description).
3. Find the route that best matches.
4. Use context clues from the entire conversation to determine the best fit.
5. Return the best match possible. You only response the name of the route that best matches the user's request, use the exact name in the routes config.
6. If no route relatively close to matches the user's latest intent or user last message is thank you or greeting, return an empty route ''.
# OUTPUT FORMAT
Your final output must follow this JSON format:
{
"route": "route_name" # The matched route name, or empty string '' if no match
}
Based on your analysis, provide only the JSON object as your final output with no additional text, explanations, or whitespace.
"#;

View file

@ -0,0 +1,151 @@
use std::sync::Arc;
use bytes::Bytes;
use common::api::open_ai::ChatCompletionsRequest;
use common::consts::ARCH_PROVIDER_HINT_HEADER;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use hyper::body::Body;
use hyper::header;
use hyper::{Request, Response, StatusCode};
use tracing::info;
use crate::router::llm_router::RouterService;
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}
pub async fn chat_completion(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
llm_provider_endpoint: String,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let max = request.body().size_hint().upper().unwrap_or(u64::MAX);
if max > 1024 * 1024 {
let error_msg = format!("Request body too large: {} bytes", max);
let mut too_large = Response::new(full(error_msg));
*too_large.status_mut() = StatusCode::PAYLOAD_TOO_LARGE;
return Ok(too_large);
}
let mut request_headers = request.headers().clone();
info!(
"Request headers: {}",
request_headers
.iter()
.map(|(k, v)| format!("{}: {}", k, v.to_str().unwrap_or_default()))
.collect::<Vec<String>>()
.join(", ")
);
let chat_request_bytes = request.collect().await?.to_bytes();
let chat_completion_request: ChatCompletionsRequest =
match serde_json::from_slice(&chat_request_bytes) {
Ok(request) => request,
Err(err) => {
let err_msg = format!("Failed to parse request body: {}", err);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
};
info!(
"Received request: {}",
&serde_json::to_string(&chat_completion_request).unwrap()
);
let trace_parent = request_headers
.iter()
.find(|(ty, _)| ty.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
let selected_llm = match router_service
.determine_route(&chat_completion_request.messages, trace_parent.clone())
.await
{
Ok(route) => route,
Err(err) => {
let err_msg = format!("Failed to determine route: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
info!(
"sending request to llm provider: {} with llm model: {:?}",
llm_provider_endpoint, selected_llm
);
if let Some(trace_parent) = trace_parent {
request_headers.insert(
header::HeaderName::from_static("traceparent"),
header::HeaderValue::from_str(&trace_parent).unwrap(),
);
}
if let Some(selected_llm) = selected_llm {
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&selected_llm).unwrap(),
);
}
let llm_response = match reqwest::Client::new()
.post(llm_provider_endpoint)
.headers(request_headers)
.body(chat_request_bytes)
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
// if chat_completion_request.stream {
// let mut byte_stream = llm_response.bytes_stream();
// while let Some(item) = byte_stream.next().await {
// let item = match item {
// Ok(item) => item,
// Err(err) => {
// let err_msg = format!("Failed to read stream: {}", err);
// let mut internal_error = Response::new(full(err_msg));
// *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
// return Ok(internal_error);
// }
// };
// info!("Received chunk: {:?}", item);
// }
// let mut ok_response = Response::new(empty());
// *ok_response.status_mut() = StatusCode::OK;
// return Ok(ok_response);
// } else {
let body = match llm_response.text().await {
Ok(body) => body,
Err(err) => {
let err_msg = format!("Failed to read response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let mut ok_response = Response::new(full(body));
*ok_response.status_mut() = StatusCode::OK;
Ok(ok_response)
// }
}

View file

@ -0,0 +1 @@
pub mod chat_completions;

View file

@ -1,3 +1,2 @@
mod consts2;
mod router;
mod types;
pub mod handlers;
pub mod router;

View file

@ -1,12 +1,11 @@
use bytes::Bytes;
use common::api::open_ai::{ChatCompletionsRequest, ChatCompletionsResponse, Message};
use common::configuration::{Configuration, LlmProvider};
use common::consts::{ARCH_PROVIDER_HINT_HEADER, USER_ROLE};
use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full};
use hyper::body::{Body, Incoming};
use common::configuration::Configuration;
use common::utils::shorten_string;
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{header, Method, Request, Response, StatusCode};
use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use opentelemetry::global::BoxedTracer;
use opentelemetry::trace::FutureExt;
@ -18,16 +17,15 @@ use opentelemetry::{
use opentelemetry_http::HeaderExtractor;
use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider};
use opentelemetry_stdout::SpanExporter;
use types::types::LlmRouterResponse;
use std::env;
use std::sync::{Arc, OnceLock};
use std::{env, fs};
use tokio::net::TcpListener;
use tracing::info;
use tracing_subscriber::EnvFilter;
use whitestaff::handlers::chat_completions::chat_completion;
use whitestaff::router::llm_router::RouterService;
mod consts2;
use consts2::SYSTEM_PROMPT_Z;
mod types;
pub mod router;
const BIND_ADDRESS: &str = "0.0.0.0:9091";
@ -61,270 +59,6 @@ fn empty() -> BoxBody<Bytes, hyper::Error> {
.boxed()
}
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}
fn shorten_string(s: &str) -> String {
if s.len() > 80 {
format!("{}...", &s[..80])
} else {
s.to_string()
}
}
async fn chat_completion(
req: Request<hyper::body::Incoming>,
arch_config: Arc<Configuration>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let max = req.body().size_hint().upper().unwrap_or(u64::MAX);
if max > 1024 * 1024 {
let error_msg = format!("Request body too large: {} bytes", max);
let mut too_large = Response::new(full(error_msg));
*too_large.status_mut() = StatusCode::PAYLOAD_TOO_LARGE;
return Ok(too_large);
}
let mut request_headers = req.headers().clone();
info!(
"Request headers: {}",
request_headers
.iter()
.map(|(k, v)| format!("{}: {}", k, v.to_str().unwrap_or_default()))
.collect::<Vec<String>>()
.join(", ")
);
let chat_request_bytes = req.collect().await?.to_bytes();
let chat_completion_request: ChatCompletionsRequest =
match serde_json::from_slice(&chat_request_bytes) {
Ok(request) => request,
Err(err) => {
let err_msg = format!("Failed to parse request body: {}", err);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
};
info!(
"Received request: {}",
&serde_json::to_string(&chat_completion_request).unwrap()
);
let llm_providers: Vec<LlmProvider> = chat_completion_request
.metadata
.as_ref()
.and_then(|metadata| metadata.get("llm_providers"))
.and_then(|providers| serde_json::from_str::<Vec<LlmProvider>>(providers).ok())
.unwrap_or_default();
info!(
"llm_providers from request: {}...",
shorten_string(&serde_json::to_string(&llm_providers).unwrap())
);
let llm_router_with_usage = arch_config
.llm_providers
.iter()
.filter(|provider| provider.usage.is_some()).cloned()
.collect::<Vec<LlmProvider>>();
// convert the llm_providers to yaml string but only include name and usage
let llm_providers_yaml = llm_router_with_usage
.iter()
.map(|provider| {
format!(
"- name: {}()\n description: {}",
provider.name,
provider.usage.as_ref().unwrap_or(&"".to_string())
)
})
.collect::<Vec<String>>()
.join("\n");
info!(
"llm_providers from config: {}...",
shorten_string(&llm_providers_yaml.replace("\n", "\\n"))
);
let message = SYSTEM_PROMPT_Z
.replace("{routes}", &llm_providers_yaml)
.replace(
"{conversation}",
&serde_json::to_string_pretty(&chat_completion_request.messages).unwrap(),
);
let router_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: "cotran2/llama-1b-4-26".to_string(),
messages: vec![Message {
content: Some(message),
role: USER_ROLE.to_string(),
model: None,
tool_calls: None,
tool_call_id: None,
}],
tools: None,
stream: false,
stream_options: None,
metadata: None,
};
info!(
"router_request: {}...",
shorten_string(&serde_json::to_string(&router_request).unwrap())
);
let trace_parent = request_headers
.iter()
.find(|(ty, _)| ty.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default());
let mut llm_route_request_headers = header::HeaderMap::new();
llm_route_request_headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
// attach traceparent header to the llm router request
if let Some(trace_parent) = trace_parent {
llm_route_request_headers.insert(
header::HeaderName::from_static("traceparent"),
header::HeaderValue::from_str(trace_parent).unwrap(),
);
}
llm_route_request_headers.insert(
header::HeaderName::from_static("host"),
header::HeaderValue::from_static("router_model_host"),
);
let res = match reqwest::Client::new()
.post("http://localhost:9090/v1/chat/completions")
.headers(llm_route_request_headers)
.body(serde_json::to_string(&router_request).unwrap())
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let body = match res.text().await {
Ok(body) => body,
Err(err) => {
let err_msg = format!("Failed to read response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
Ok(response) => response,
Err(err) => {
let err_msg = format!("Failed to parse response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
info!(
"chat_completion_response: {}",
shorten_string(&serde_json::to_string(&chat_completion_response).unwrap())
);
let router_resp = chat_completion_response.choices[0]
.message
.content
.as_ref()
.unwrap();
let router_resp_fixed = router_resp.replace("'", "\"");
let router_response: LlmRouterResponse = match serde_json::from_str(router_resp_fixed.as_str())
{
Ok(response) => response,
Err(err) => {
let err_msg = format!("Failed to parse response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
info!(
"router_response json: {}",
serde_json::to_string(&router_response).unwrap()
);
let selecter_llm = router_response
.route
.map(|route| route.strip_suffix("()").unwrap_or_default().to_string())
.unwrap_or_default();
if selecter_llm.is_empty() {
let conversation = &serde_json::to_string(&chat_completion_request.messages).unwrap();
info!(
"no route selected for conversation: {}",
shorten_string(conversation)
);
}
info!("selecter_llm: {}", selecter_llm);
if let Some(trace_parent) = trace_parent {
request_headers.insert(
header::HeaderName::from_static("traceparent"),
header::HeaderValue::from_str(trace_parent).unwrap(),
);
}
if !selecter_llm.is_empty() {
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&selecter_llm).unwrap(),
);
}
let llm_response = match reqwest::Client::new()
.post("http://localhost:12000/v1/chat/completions")
.headers(request_headers)
.body(chat_request_bytes)
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let body = match llm_response.text().await {
Ok(body) => body,
Err(err) => {
let err_msg = format!("Failed to read response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let mut ok_response = Response::new(full(body));
*ok_response.status_mut() = StatusCode::OK;
Ok(ok_response)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let _tracer_provider = init_tracer();
@ -340,11 +74,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let arch_config_path =
env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "arch_config.yaml".to_string());
info!("Loading arch_config.yaml from {}", arch_config_path);
let arch_config =
std::fs::read_to_string(&arch_config_path).expect("Failed to read arch_config.yaml");
let config_contents =
fs::read_to_string(&arch_config_path).expect("Failed to read arch_config.yaml");
let config: Configuration =
serde_yaml::from_str(&arch_config).expect("Failed to parse arch_config.yaml");
serde_yaml::from_str(&config_contents).expect("Failed to parse arch_config.yaml");
let arch_config = Arc::new(config);
info!(
"arch_config: {:?}",
shorten_string(&serde_json::to_string(arch_config.as_ref()).unwrap())
@ -353,33 +91,35 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
info!("Listening on http://{}", bind_address);
let listener = TcpListener::bind(bind_address).await?;
let llm_provider_endpoint = "http://localhost:12000/v1/chat/completions";
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
arch_config.llm_providers.clone(),
llm_provider_endpoint.to_string(),
arch_config.routing.as_ref().unwrap().model.clone(),
));
loop {
let (stream, _) = listener.accept().await?;
let peer_addr = stream.peer_addr()?;
let io = TokioIo::new(stream);
let arch_config = Arc::clone(&arch_config);
let router_service = Arc::clone(&router_service);
let service = service_fn(move |req| {
let arch_config = Arc::clone(&arch_config);
let router_service = Arc::clone(&router_service);
let parent_cx = extract_context_from_request(&req);
info!("parent_cx: {:?}", parent_cx);
let tracer = get_tracer();
let _span = tracer
.span_builder("chat_completion")
.span_builder("router_service")
.with_kind(SpanKind::Server)
.start_with_context(tracer, &parent_cx);
async move {
match (req.method(), req.uri().path()) {
(&Method::POST, "/v1/chat/completions") => {
info!(
"config: {:?}",
shorten_string(
&serde_json::to_string(&arch_config.llm_providers).unwrap()
)
);
chat_completion(req, arch_config)
chat_completion(req, router_service, llm_provider_endpoint.to_string())
.with_context(parent_cx)
.await
}

View file

@ -1,32 +0,0 @@
pub const ARCH_ROUTER_V1_SYSTEM_PROMPT: &str = r#"
You are an advanced Routing Assistant designed to select the optimal route based on user requests.
Your task is to analyze conversations and match them to the most appropriate predefined route.
Review the available routes config:
# ROUTES CONFIG START
{routes}
# ROUTES CONFIG END
Examine the following conversation between a user and an assistant:
# CONVERSATION START
{conversation}
# CONVERSATION END
Your goal is to identify the most appropriate route that matches the user's LATEST intent. Follow these steps:
1. Carefully read and analyze the provided conversation, focusing on the user's latest request and the conversation scenario.
2. Check if the user's request and scenario matches any of the routes in the routing configuration (focus on the description).
3. Find the route that best matches.
4. Use context clues from the entire conversation to determine the best fit.
5. Return the best match possible. You only response the name of the route that best matches the user's request, use the exact name in the routes config.
6. If no route relatively close to matches the user's latest intent or user last message is thank you or greeting, return an empty route ''.
# OUTPUT FORMAT
Your final output must follow this JSON format:
{
"route": "route_name" # The matched route name, or empty string '' if no match
}
Based on your analysis, provide only the JSON object as your final output with no additional text, explanations, or whitespace.
"#;

View file

@ -1,35 +1,44 @@
use std::sync::Arc;
use common::{
api::open_ai::{ChatCompletionsRequest, ChatCompletionsResponse, Message},
api::open_ai::{ChatCompletionsResponse, Message},
configuration::LlmProvider,
consts::USER_ROLE,
consts::ARCH_PROVIDER_HINT_HEADER,
utils::shorten_string,
};
use hyper::header;
use thiserror::Error;
use tracing::info;
use crate::{router::consts::ARCH_ROUTER_V1_SYSTEM_PROMPT, types::types::LlmRouterResponse};
use super::router_model::RouterModel;
// Domain Service example
pub struct RouterService {
providers: Vec<LlmProvider>,
providers_with_usage: Vec<LlmProvider>,
router_url: String,
client: reqwest::Client,
llm_providers_with_usage_yaml: String,
router_model: Arc<dyn RouterModel>,
routing_model_name: String,
}
#[derive(Debug, Error)]
pub enum RoutingError {
#[error("Failed to send request: {0}")]
RequestError(#[from] reqwest::Error),
#[error("Failed to parse JSON: {0}")]
JsonError(#[from] serde_json::Error),
#[error("Router model error: {0}")]
RouterModelError(#[from] super::router_model::RoutingModelError),
}
type Result<T> = std::result::Result<T, RoutingError>;
pub type Result<T> = std::result::Result<T, RoutingError>;
impl RouterService {
pub fn new(providers: Vec<LlmProvider>, router_url: String) -> Self {
pub fn new(
providers: Vec<LlmProvider>,
router_url: String,
routing_model_name: String,
) -> Self {
let providers_with_usage = providers
.iter()
.filter(|provider| provider.usage.is_some())
@ -51,74 +60,54 @@ impl RouterService {
info!(
"llm_providers from config with usage: {}...",
&llm_providers_with_usage_yaml.replace("\n", "\\n")
shorten_string(&llm_providers_with_usage_yaml.replace("\n", "\\n"))
);
let router_model = Arc::new(super::router_model_v1::RouterModelV1::new(
llm_providers_with_usage_yaml.clone(),
routing_model_name.clone(),
));
RouterService {
providers,
providers_with_usage,
router_url,
llm_providers_with_usage_yaml,
client: reqwest::Client::new(),
router_model,
routing_model_name,
}
}
pub async fn determine_route(
&self,
chat_completion_request: &ChatCompletionsRequest,
) -> Result<String> {
let message = ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", &self.llm_providers_with_usage_yaml)
.replace(
"{conversation}",
&serde_json::to_string_pretty(&chat_completion_request.messages).unwrap(),
);
let router_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: "cotran2/llama-1b-4-26".to_string(),
messages: vec![Message {
content: Some(message),
role: USER_ROLE.to_string(),
model: None,
tool_calls: None,
tool_call_id: None,
}],
tools: None,
stream: false,
stream_options: None,
metadata: None,
};
messages: &[Message],
trace_parent: Option<String>,
) -> Result<Option<String>> {
let router_request = self.router_model.generate_request(messages);
info!(
"router_request: {}",
&serde_json::to_string(&router_request).unwrap()
shorten_string(&serde_json::to_string(&router_request).unwrap())
);
// let trace_parent = request_headers
// .iter()
// .find(|(ty, _)| ty.as_str() == "traceparent")
// .map(|(_, value)| value.to_str().unwrap_or_default());
let mut llm_route_request_headers = header::HeaderMap::new();
llm_route_request_headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
// // attach traceparent header to the llm router request
// if let Some(trace_parent) = trace_parent {
// llm_route_request_headers.insert(
// header::HeaderName::from_static("traceparent"),
// header::HeaderValue::from_str(trace_parent).unwrap(),
// );
// }
llm_route_request_headers.insert(
header::HeaderName::from_static("host"),
header::HeaderValue::from_static("router_model_host"),
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
header::HeaderValue::from_str(&self.routing_model_name).unwrap(),
);
let res = reqwest::Client::new()
if let Some(trace_parent) = trace_parent {
llm_route_request_headers.insert(
header::HeaderName::from_static("traceparent"),
header::HeaderValue::from_str(&trace_parent).unwrap(),
);
}
let res = self
.client
.post(&self.router_url)
.headers(llm_route_request_headers)
.body(serde_json::to_string(&router_request).unwrap())
@ -129,36 +118,14 @@ impl RouterService {
let chat_completion_response: ChatCompletionsResponse = serde_json::from_str(&body)?;
info!(
"chat_completion_response: {}",
&serde_json::to_string(&chat_completion_response).unwrap()
);
let selected_llm = self.router_model.parse_response(
chat_completion_response.choices[0]
.message
.content
.as_ref()
.unwrap(),
)?;
let router_resp = chat_completion_response.choices[0]
.message
.content
.as_ref()
.unwrap();
let router_resp_fixed = router_resp.replace("'", "\"");
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
info!(
"router_response json: {}",
serde_json::to_string(&router_response).unwrap()
);
let selecter_llm = router_response
.route
.map(|route| route.strip_suffix("()").unwrap_or_default().to_string())
.unwrap_or_default();
if selecter_llm.is_empty() {
let conversation = &serde_json::to_string(&chat_completion_request.messages).unwrap();
info!("no route selected for conversation: {}", conversation);
}
info!("selecter_llm: {}", selecter_llm);
Ok(self.router_url.clone())
Ok(selected_llm)
}
}

View file

@ -1,2 +1,3 @@
pub mod llm_router;
mod consts;
pub mod router_model;
pub mod router_model_v1;

View file

@ -0,0 +1,15 @@
use common::api::open_ai::{ChatCompletionsRequest, Message};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum RoutingModelError {
#[error("Failed to parse JSON: {0}")]
JsonError(#[from] serde_json::Error),
}
pub type Result<T> = std::result::Result<T, RoutingModelError>;
pub trait RouterModel: Send + Sync {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest;
fn parse_response(&self, content: &str) -> Result<Option<String>>;
}

View file

@ -0,0 +1,140 @@
use common::{
api::open_ai::{ChatCompletionsRequest, Message},
consts::USER_ROLE,
};
use serde::{Deserialize, Serialize};
use tracing::info;
use super::router_model::{RouterModel, RoutingModelError};
pub const ARCH_ROUTER_V1_SYSTEM_PROMPT: &str = r#"
You are an advanced Routing Assistant designed to select the optimal route based on user requests.
Your task is to analyze conversations and match them to the most appropriate predefined route.
Review the available routes config:
# ROUTES CONFIG START
{routes}
# ROUTES CONFIG END
Examine the following conversation between a user and an assistant:
# CONVERSATION START
{conversation}
# CONVERSATION END
Your goal is to identify the most appropriate route that matches the user's LATEST intent. Follow these steps:
1. Carefully read and analyze the provided conversation, focusing on the user's latest request and the conversation scenario.
2. Check if the user's request and scenario matches any of the routes in the routing configuration (focus on the description).
3. Find the route that best matches.
4. Use context clues from the entire conversation to determine the best fit.
5. Return the best match possible. You only response the name of the route that best matches the user's request, use the exact name in the routes config.
6. If no route relatively close to matches the user's latest intent or user last message is thank you or greeting, return an empty route ''.
# OUTPUT FORMAT
Your final output must follow this JSON format:
{
"route": "route_name" # The matched route name, or empty string '' if no match
}
Based on your analysis, provide only the JSON object as your final output with no additional text, explanations, or whitespace.
"#;
pub type Result<T> = std::result::Result<T, RoutingModelError>;
pub struct RouterModelV1 {
llm_providers_with_usage_yaml: String,
routing_model: String,
}
impl RouterModelV1 {
pub fn new(llm_providers_with_usage_yaml: String, routing_model: String) -> Self {
RouterModelV1 {
llm_providers_with_usage_yaml,
routing_model,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LlmRouterResponse {
pub route: Option<String>,
}
impl RouterModel for RouterModelV1 {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
let message = ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", &self.llm_providers_with_usage_yaml)
.replace(
"{conversation}",
&serde_json::to_string_pretty(messages).unwrap(),
);
ChatCompletionsRequest {
model: self.routing_model.clone(),
messages: vec![Message {
content: Some(message),
role: USER_ROLE.to_string(),
model: None,
tool_calls: None,
tool_call_id: None,
}],
tools: None,
stream: false,
stream_options: None,
metadata: None,
}
}
fn parse_response(&self, content: &str) -> Result<Option<String>> {
let router_resp_fixed = fix_json_response(content);
info!(
"router response (fixed): {}",
router_resp_fixed.replace("\n", "\\n")
);
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
let selecter_llm = router_response
.route
.map(|route| route.strip_suffix("()").unwrap_or_default().to_string())
.unwrap();
if selecter_llm.is_empty() {
return Ok(None);
}
Ok(Some(selecter_llm))
}
}
fn fix_json_response(body: &str) -> String {
let mut updated_body = body.to_string();
updated_body = updated_body.replace("'", "\"");
if updated_body.contains("\\n") {
updated_body = updated_body.replace("\\n", "");
}
if updated_body.starts_with("```json") {
updated_body = updated_body
.strip_prefix("```json")
.unwrap_or(&updated_body)
.to_string();
}
if updated_body.ends_with("```") {
updated_body = updated_body
.strip_suffix("```")
.unwrap_or(&updated_body)
.to_string();
}
updated_body
}
impl std::fmt::Debug for dyn RouterModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "RouterModel")
}
}

View file

@ -1 +0,0 @@
pub mod types;

View file

@ -1,6 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRouterResponse {
pub route: Option<String>,
}

View file

@ -1,10 +1,7 @@
version: "0.1-beta"
endpoints:
gcp_hosted_outer_llm:
endpoint: 34.46.85.85:8000
http_host: 34.46.85.85
# endpoint: host.docker.internal:11223
routing:
model: gpt-4o
listeners:
egress_traffic: