From f13fc76a4a1ca060d0446935c9b9efa02fa00de6 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Mon, 12 May 2025 12:55:59 -0700 Subject: [PATCH] more updates --- arch/arch_config_schema.yaml | 6 + crates/Cargo.lock | 16 + crates/common/src/configuration.rs | 6 + crates/common/src/lib.rs | 1 + crates/common/src/utils.rs | 7 + crates/whitestaff/Cargo.toml | 7 +- crates/whitestaff/src/consts2.rs | 32 -- .../src/handlers/chat_completions.rs | 151 +++++++++ crates/whitestaff/src/handlers/mod.rs | 1 + crates/whitestaff/src/lib.rs | 5 +- crates/whitestaff/src/main.rs | 316 ++---------------- crates/whitestaff/src/router/consts.rs | 32 -- crates/whitestaff/src/router/llm_router.rs | 133 +++----- crates/whitestaff/src/router/mod.rs | 3 +- crates/whitestaff/src/router/router_model.rs | 15 + .../whitestaff/src/router/router_model_v1.rs | 140 ++++++++ crates/whitestaff/src/types/mod.rs | 1 - crates/whitestaff/src/types/types.rs | 6 - .../preference_based_routing/arch_config.yaml | 7 +- 19 files changed, 431 insertions(+), 454 deletions(-) create mode 100644 crates/common/src/utils.rs delete mode 100644 crates/whitestaff/src/consts2.rs create mode 100644 crates/whitestaff/src/handlers/chat_completions.rs create mode 100644 crates/whitestaff/src/handlers/mod.rs delete mode 100644 crates/whitestaff/src/router/consts.rs create mode 100644 crates/whitestaff/src/router/router_model.rs create mode 100644 crates/whitestaff/src/router/router_model_v1.rs delete mode 100644 crates/whitestaff/src/types/mod.rs delete mode 100644 crates/whitestaff/src/types/types.rs diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index 467d6d22..a72db695 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -227,6 +227,12 @@ properties: enum: - llm - prompt + routing: + type: object + properties: + model: + type: string + additionalProperties: false prompt_guards: type: object properties: diff --git a/crates/Cargo.lock b/crates/Cargo.lock index fffb59a0..d8885a21 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -2200,11 +2200,13 @@ dependencies = [ "system-configuration", "tokio", "tokio-native-tls", + "tokio-util", "tower", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "windows-registry", ] @@ -3281,6 +3283,19 @@ dependencies = [ "wasmparser 0.219.1", ] +[[package]] +name = "wasm-streams" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e072d4e72f700fb3443d8fe94a39315df013eef1104903cdb0a2abd322bbecd" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "wasmparser" version = "0.212.0" @@ -3623,6 +3638,7 @@ dependencies = [ "eventsource-client", "eventsource-stream", "futures", + "futures-util", "http-body-util", "hyper 1.6.0", "hyper-util", diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 4518e368..2fb0238f 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -6,6 +6,11 @@ use crate::api::open_ai::{ ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType, }; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Routing { + pub model: String, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Configuration { pub version: String, @@ -19,6 +24,7 @@ pub struct Configuration { pub ratelimits: Option>, pub tracing: Option, pub mode: Option, + pub routing: Option, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 32549893..76c368f1 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -11,3 +11,4 @@ pub mod routing; pub mod stats; pub mod tokenizer; pub mod tracing; +pub mod utils; diff --git a/crates/common/src/utils.rs b/crates/common/src/utils.rs new file mode 100644 index 00000000..fa31d166 --- /dev/null +++ b/crates/common/src/utils.rs @@ -0,0 +1,7 @@ +pub fn shorten_string(s: &str) -> String { + if s.len() > 80 { + format!("{}...", &s[..80]) + } else { + s.to_string() + } +} diff --git a/crates/whitestaff/Cargo.toml b/crates/whitestaff/Cargo.toml index ca9b0ca0..375bacef 100644 --- a/crates/whitestaff/Cargo.toml +++ b/crates/whitestaff/Cargo.toml @@ -9,15 +9,16 @@ common = { version = "0.1.0", path = "../common" } eventsource-client = "0.15.0" eventsource-stream = "0.2.3" futures = "0.3.31" +futures-util = "0.3.31" http-body-util = "0.1.3" -hyper = { version="1.6.0", features = ["full"] } +hyper = { version = "1.6.0", features = ["full"] } hyper-util = "0.1.11" opentelemetry = "0.29.1" opentelemetry-http = "0.29.0" opentelemetry-otlp = "0.29.0" opentelemetry-stdout = "0.29.0" opentelemetry_sdk = "0.29.0" -reqwest = "0.12.15" +reqwest = { version = "0.12.15", features = ["stream"] } serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" serde_yaml = "0.9.34" @@ -26,4 +27,4 @@ tokio = { version = "1.44.2", features = ["full"] } tokio-stream = "0.1.17" tracing = "0.1.41" tracing-opentelemetry = "0.30.0" -tracing-subscriber = { version="0.3.19", features = ["env-filter", "fmt"] } +tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt"] } diff --git a/crates/whitestaff/src/consts2.rs b/crates/whitestaff/src/consts2.rs deleted file mode 100644 index 51764b71..00000000 --- a/crates/whitestaff/src/consts2.rs +++ /dev/null @@ -1,32 +0,0 @@ -pub const SYSTEM_PROMPT_Z: &str = r#" -You are an advanced Routing Assistant designed to select the optimal route based on user requests. -Your task is to analyze conversations and match them to the most appropriate predefined route. -Review the available routes config: - -# ROUTES CONFIG START -{routes} -# ROUTES CONFIG END - -Examine the following conversation between a user and an assistant: - -# CONVERSATION START -{conversation} -# CONVERSATION END - -Your goal is to identify the most appropriate route that matches the user's LATEST intent. Follow these steps: - -1. Carefully read and analyze the provided conversation, focusing on the user's latest request and the conversation scenario. -2. Check if the user's request and scenario matches any of the routes in the routing configuration (focus on the description). -3. Find the route that best matches. -4. Use context clues from the entire conversation to determine the best fit. -5. Return the best match possible. You only response the name of the route that best matches the user's request, use the exact name in the routes config. -6. If no route relatively close to matches the user's latest intent or user last message is thank you or greeting, return an empty route ''. - -# OUTPUT FORMAT -Your final output must follow this JSON format: -{ - "route": "route_name" # The matched route name, or empty string '' if no match -} - -Based on your analysis, provide only the JSON object as your final output with no additional text, explanations, or whitespace. -"#; diff --git a/crates/whitestaff/src/handlers/chat_completions.rs b/crates/whitestaff/src/handlers/chat_completions.rs new file mode 100644 index 00000000..49848b37 --- /dev/null +++ b/crates/whitestaff/src/handlers/chat_completions.rs @@ -0,0 +1,151 @@ +use std::sync::Arc; + +use bytes::Bytes; +use common::api::open_ai::ChatCompletionsRequest; +use common::consts::ARCH_PROVIDER_HINT_HEADER; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyExt, Full}; +use hyper::body::Body; +use hyper::header; +use hyper::{Request, Response, StatusCode}; +use tracing::info; + +use crate::router::llm_router::RouterService; + +fn full>(chunk: T) -> BoxBody { + Full::new(chunk.into()) + .map_err(|never| match never {}) + .boxed() +} + +pub async fn chat_completion( + request: Request, + router_service: Arc, + llm_provider_endpoint: String, +) -> Result>, hyper::Error> { + let max = request.body().size_hint().upper().unwrap_or(u64::MAX); + if max > 1024 * 1024 { + let error_msg = format!("Request body too large: {} bytes", max); + let mut too_large = Response::new(full(error_msg)); + *too_large.status_mut() = StatusCode::PAYLOAD_TOO_LARGE; + return Ok(too_large); + } + + let mut request_headers = request.headers().clone(); + + info!( + "Request headers: {}", + request_headers + .iter() + .map(|(k, v)| format!("{}: {}", k, v.to_str().unwrap_or_default())) + .collect::>() + .join(", ") + ); + let chat_request_bytes = request.collect().await?.to_bytes(); + let chat_completion_request: ChatCompletionsRequest = + match serde_json::from_slice(&chat_request_bytes) { + Ok(request) => request, + Err(err) => { + let err_msg = format!("Failed to parse request body: {}", err); + let mut bad_request = Response::new(full(err_msg)); + *bad_request.status_mut() = StatusCode::BAD_REQUEST; + return Ok(bad_request); + } + }; + + info!( + "Received request: {}", + &serde_json::to_string(&chat_completion_request).unwrap() + ); + + let trace_parent = request_headers + .iter() + .find(|(ty, _)| ty.as_str() == "traceparent") + .map(|(_, value)| value.to_str().unwrap_or_default().to_string()); + + let selected_llm = match router_service + .determine_route(&chat_completion_request.messages, trace_parent.clone()) + .await + { + Ok(route) => route, + Err(err) => { + let err_msg = format!("Failed to determine route: {}", err); + let mut internal_error = Response::new(full(err_msg)); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return Ok(internal_error); + } + }; + + info!( + "sending request to llm provider: {} with llm model: {:?}", + llm_provider_endpoint, selected_llm + ); + + if let Some(trace_parent) = trace_parent { + request_headers.insert( + header::HeaderName::from_static("traceparent"), + header::HeaderValue::from_str(&trace_parent).unwrap(), + ); + } + + if let Some(selected_llm) = selected_llm { + request_headers.insert( + ARCH_PROVIDER_HINT_HEADER, + header::HeaderValue::from_str(&selected_llm).unwrap(), + ); + } + + let llm_response = match reqwest::Client::new() + .post(llm_provider_endpoint) + .headers(request_headers) + .body(chat_request_bytes) + .send() + .await + { + Ok(res) => res, + Err(err) => { + let err_msg = format!("Failed to send request: {}", err); + let mut internal_error = Response::new(full(err_msg)); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return Ok(internal_error); + } + }; + + // if chat_completion_request.stream { + // let mut byte_stream = llm_response.bytes_stream(); + + // while let Some(item) = byte_stream.next().await { + // let item = match item { + // Ok(item) => item, + // Err(err) => { + // let err_msg = format!("Failed to read stream: {}", err); + // let mut internal_error = Response::new(full(err_msg)); + // *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + // return Ok(internal_error); + // } + // }; + + // info!("Received chunk: {:?}", item); + // } + + // let mut ok_response = Response::new(empty()); + // *ok_response.status_mut() = StatusCode::OK; + + // return Ok(ok_response); + // } else { + let body = match llm_response.text().await { + Ok(body) => body, + Err(err) => { + let err_msg = format!("Failed to read response: {}", err); + let mut internal_error = Response::new(full(err_msg)); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return Ok(internal_error); + } + }; + + let mut ok_response = Response::new(full(body)); + *ok_response.status_mut() = StatusCode::OK; + + Ok(ok_response) + // } +} diff --git a/crates/whitestaff/src/handlers/mod.rs b/crates/whitestaff/src/handlers/mod.rs new file mode 100644 index 00000000..dcf38982 --- /dev/null +++ b/crates/whitestaff/src/handlers/mod.rs @@ -0,0 +1 @@ +pub mod chat_completions; diff --git a/crates/whitestaff/src/lib.rs b/crates/whitestaff/src/lib.rs index dfef3bc2..0591d7d0 100644 --- a/crates/whitestaff/src/lib.rs +++ b/crates/whitestaff/src/lib.rs @@ -1,3 +1,2 @@ -mod consts2; -mod router; -mod types; +pub mod handlers; +pub mod router; diff --git a/crates/whitestaff/src/main.rs b/crates/whitestaff/src/main.rs index 171b8e8e..96e4aacd 100644 --- a/crates/whitestaff/src/main.rs +++ b/crates/whitestaff/src/main.rs @@ -1,12 +1,11 @@ use bytes::Bytes; -use common::api::open_ai::{ChatCompletionsRequest, ChatCompletionsResponse, Message}; -use common::configuration::{Configuration, LlmProvider}; -use common::consts::{ARCH_PROVIDER_HINT_HEADER, USER_ROLE}; -use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full}; -use hyper::body::{Body, Incoming}; +use common::configuration::Configuration; +use common::utils::shorten_string; +use http_body_util::{combinators::BoxBody, BodyExt, Empty}; +use hyper::body::Incoming; use hyper::server::conn::http1; use hyper::service::service_fn; -use hyper::{header, Method, Request, Response, StatusCode}; +use hyper::{Method, Request, Response, StatusCode}; use hyper_util::rt::TokioIo; use opentelemetry::global::BoxedTracer; use opentelemetry::trace::FutureExt; @@ -18,16 +17,15 @@ use opentelemetry::{ use opentelemetry_http::HeaderExtractor; use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider}; use opentelemetry_stdout::SpanExporter; -use types::types::LlmRouterResponse; -use std::env; use std::sync::{Arc, OnceLock}; +use std::{env, fs}; use tokio::net::TcpListener; use tracing::info; use tracing_subscriber::EnvFilter; +use whitestaff::handlers::chat_completions::chat_completion; +use whitestaff::router::llm_router::RouterService; -mod consts2; -use consts2::SYSTEM_PROMPT_Z; -mod types; +pub mod router; const BIND_ADDRESS: &str = "0.0.0.0:9091"; @@ -61,270 +59,6 @@ fn empty() -> BoxBody { .boxed() } -fn full>(chunk: T) -> BoxBody { - Full::new(chunk.into()) - .map_err(|never| match never {}) - .boxed() -} - -fn shorten_string(s: &str) -> String { - if s.len() > 80 { - format!("{}...", &s[..80]) - } else { - s.to_string() - } -} - -async fn chat_completion( - req: Request, - arch_config: Arc, -) -> Result>, hyper::Error> { - let max = req.body().size_hint().upper().unwrap_or(u64::MAX); - if max > 1024 * 1024 { - let error_msg = format!("Request body too large: {} bytes", max); - let mut too_large = Response::new(full(error_msg)); - *too_large.status_mut() = StatusCode::PAYLOAD_TOO_LARGE; - return Ok(too_large); - } - - let mut request_headers = req.headers().clone(); - - info!( - "Request headers: {}", - request_headers - .iter() - .map(|(k, v)| format!("{}: {}", k, v.to_str().unwrap_or_default())) - .collect::>() - .join(", ") - ); - let chat_request_bytes = req.collect().await?.to_bytes(); - let chat_completion_request: ChatCompletionsRequest = - match serde_json::from_slice(&chat_request_bytes) { - Ok(request) => request, - Err(err) => { - let err_msg = format!("Failed to parse request body: {}", err); - let mut bad_request = Response::new(full(err_msg)); - *bad_request.status_mut() = StatusCode::BAD_REQUEST; - return Ok(bad_request); - } - }; - - info!( - "Received request: {}", - &serde_json::to_string(&chat_completion_request).unwrap() - ); - - let llm_providers: Vec = chat_completion_request - .metadata - .as_ref() - .and_then(|metadata| metadata.get("llm_providers")) - .and_then(|providers| serde_json::from_str::>(providers).ok()) - .unwrap_or_default(); - - info!( - "llm_providers from request: {}...", - shorten_string(&serde_json::to_string(&llm_providers).unwrap()) - ); - - let llm_router_with_usage = arch_config - .llm_providers - .iter() - .filter(|provider| provider.usage.is_some()).cloned() - .collect::>(); - - // convert the llm_providers to yaml string but only include name and usage - let llm_providers_yaml = llm_router_with_usage - .iter() - .map(|provider| { - format!( - "- name: {}()\n description: {}", - provider.name, - provider.usage.as_ref().unwrap_or(&"".to_string()) - ) - }) - .collect::>() - .join("\n"); - - info!( - "llm_providers from config: {}...", - shorten_string(&llm_providers_yaml.replace("\n", "\\n")) - ); - - let message = SYSTEM_PROMPT_Z - .replace("{routes}", &llm_providers_yaml) - .replace( - "{conversation}", - &serde_json::to_string_pretty(&chat_completion_request.messages).unwrap(), - ); - - let router_request: ChatCompletionsRequest = ChatCompletionsRequest { - model: "cotran2/llama-1b-4-26".to_string(), - messages: vec![Message { - content: Some(message), - role: USER_ROLE.to_string(), - model: None, - tool_calls: None, - tool_call_id: None, - }], - tools: None, - stream: false, - stream_options: None, - metadata: None, - }; - - info!( - "router_request: {}...", - shorten_string(&serde_json::to_string(&router_request).unwrap()) - ); - - let trace_parent = request_headers - .iter() - .find(|(ty, _)| ty.as_str() == "traceparent") - .map(|(_, value)| value.to_str().unwrap_or_default()); - - let mut llm_route_request_headers = header::HeaderMap::new(); - llm_route_request_headers.insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ); - - // attach traceparent header to the llm router request - if let Some(trace_parent) = trace_parent { - llm_route_request_headers.insert( - header::HeaderName::from_static("traceparent"), - header::HeaderValue::from_str(trace_parent).unwrap(), - ); - } - - llm_route_request_headers.insert( - header::HeaderName::from_static("host"), - header::HeaderValue::from_static("router_model_host"), - ); - - let res = match reqwest::Client::new() - .post("http://localhost:9090/v1/chat/completions") - .headers(llm_route_request_headers) - .body(serde_json::to_string(&router_request).unwrap()) - .send() - .await - { - Ok(res) => res, - Err(err) => { - let err_msg = format!("Failed to send request: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); - } - }; - - let body = match res.text().await { - Ok(body) => body, - Err(err) => { - let err_msg = format!("Failed to read response: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); - } - }; - - let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) { - Ok(response) => response, - Err(err) => { - let err_msg = format!("Failed to parse response: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); - } - }; - - info!( - "chat_completion_response: {}", - shorten_string(&serde_json::to_string(&chat_completion_response).unwrap()) - ); - - let router_resp = chat_completion_response.choices[0] - .message - .content - .as_ref() - .unwrap(); - let router_resp_fixed = router_resp.replace("'", "\""); - let router_response: LlmRouterResponse = match serde_json::from_str(router_resp_fixed.as_str()) - { - Ok(response) => response, - Err(err) => { - let err_msg = format!("Failed to parse response: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); - } - }; - - info!( - "router_response json: {}", - serde_json::to_string(&router_response).unwrap() - ); - - let selecter_llm = router_response - .route - .map(|route| route.strip_suffix("()").unwrap_or_default().to_string()) - .unwrap_or_default(); - - if selecter_llm.is_empty() { - let conversation = &serde_json::to_string(&chat_completion_request.messages).unwrap(); - info!( - "no route selected for conversation: {}", - shorten_string(conversation) - ); - } - - info!("selecter_llm: {}", selecter_llm); - - if let Some(trace_parent) = trace_parent { - request_headers.insert( - header::HeaderName::from_static("traceparent"), - header::HeaderValue::from_str(trace_parent).unwrap(), - ); - } - - if !selecter_llm.is_empty() { - request_headers.insert( - ARCH_PROVIDER_HINT_HEADER, - header::HeaderValue::from_str(&selecter_llm).unwrap(), - ); - } - - let llm_response = match reqwest::Client::new() - .post("http://localhost:12000/v1/chat/completions") - .headers(request_headers) - .body(chat_request_bytes) - .send() - .await - { - Ok(res) => res, - Err(err) => { - let err_msg = format!("Failed to send request: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); - } - }; - - let body = match llm_response.text().await { - Ok(body) => body, - Err(err) => { - let err_msg = format!("Failed to read response: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); - } - }; - - let mut ok_response = Response::new(full(body)); - *ok_response.status_mut() = StatusCode::OK; - - Ok(ok_response) -} - #[tokio::main] async fn main() -> Result<(), Box> { let _tracer_provider = init_tracer(); @@ -340,11 +74,15 @@ async fn main() -> Result<(), Box> { let arch_config_path = env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "arch_config.yaml".to_string()); info!("Loading arch_config.yaml from {}", arch_config_path); - let arch_config = - std::fs::read_to_string(&arch_config_path).expect("Failed to read arch_config.yaml"); + + let config_contents = + fs::read_to_string(&arch_config_path).expect("Failed to read arch_config.yaml"); + let config: Configuration = - serde_yaml::from_str(&arch_config).expect("Failed to parse arch_config.yaml"); + serde_yaml::from_str(&config_contents).expect("Failed to parse arch_config.yaml"); + let arch_config = Arc::new(config); + info!( "arch_config: {:?}", shorten_string(&serde_json::to_string(arch_config.as_ref()).unwrap()) @@ -353,33 +91,35 @@ async fn main() -> Result<(), Box> { info!("Listening on http://{}", bind_address); let listener = TcpListener::bind(bind_address).await?; + let llm_provider_endpoint = "http://localhost:12000/v1/chat/completions"; + + let router_service: Arc = Arc::new(RouterService::new( + arch_config.llm_providers.clone(), + llm_provider_endpoint.to_string(), + arch_config.routing.as_ref().unwrap().model.clone(), + )); + loop { let (stream, _) = listener.accept().await?; let peer_addr = stream.peer_addr()?; let io = TokioIo::new(stream); - let arch_config = Arc::clone(&arch_config); + let router_service = Arc::clone(&router_service); let service = service_fn(move |req| { - let arch_config = Arc::clone(&arch_config); + let router_service = Arc::clone(&router_service); let parent_cx = extract_context_from_request(&req); info!("parent_cx: {:?}", parent_cx); let tracer = get_tracer(); let _span = tracer - .span_builder("chat_completion") + .span_builder("router_service") .with_kind(SpanKind::Server) .start_with_context(tracer, &parent_cx); async move { match (req.method(), req.uri().path()) { (&Method::POST, "/v1/chat/completions") => { - info!( - "config: {:?}", - shorten_string( - &serde_json::to_string(&arch_config.llm_providers).unwrap() - ) - ); - chat_completion(req, arch_config) + chat_completion(req, router_service, llm_provider_endpoint.to_string()) .with_context(parent_cx) .await } diff --git a/crates/whitestaff/src/router/consts.rs b/crates/whitestaff/src/router/consts.rs deleted file mode 100644 index 1128bf32..00000000 --- a/crates/whitestaff/src/router/consts.rs +++ /dev/null @@ -1,32 +0,0 @@ -pub const ARCH_ROUTER_V1_SYSTEM_PROMPT: &str = r#" -You are an advanced Routing Assistant designed to select the optimal route based on user requests. -Your task is to analyze conversations and match them to the most appropriate predefined route. -Review the available routes config: - -# ROUTES CONFIG START -{routes} -# ROUTES CONFIG END - -Examine the following conversation between a user and an assistant: - -# CONVERSATION START -{conversation} -# CONVERSATION END - -Your goal is to identify the most appropriate route that matches the user's LATEST intent. Follow these steps: - -1. Carefully read and analyze the provided conversation, focusing on the user's latest request and the conversation scenario. -2. Check if the user's request and scenario matches any of the routes in the routing configuration (focus on the description). -3. Find the route that best matches. -4. Use context clues from the entire conversation to determine the best fit. -5. Return the best match possible. You only response the name of the route that best matches the user's request, use the exact name in the routes config. -6. If no route relatively close to matches the user's latest intent or user last message is thank you or greeting, return an empty route ''. - -# OUTPUT FORMAT -Your final output must follow this JSON format: -{ - "route": "route_name" # The matched route name, or empty string '' if no match -} - -Based on your analysis, provide only the JSON object as your final output with no additional text, explanations, or whitespace. -"#; diff --git a/crates/whitestaff/src/router/llm_router.rs b/crates/whitestaff/src/router/llm_router.rs index beb6f85e..afb0a94a 100644 --- a/crates/whitestaff/src/router/llm_router.rs +++ b/crates/whitestaff/src/router/llm_router.rs @@ -1,35 +1,44 @@ +use std::sync::Arc; + use common::{ - api::open_ai::{ChatCompletionsRequest, ChatCompletionsResponse, Message}, + api::open_ai::{ChatCompletionsResponse, Message}, configuration::LlmProvider, - consts::USER_ROLE, + consts::ARCH_PROVIDER_HINT_HEADER, + utils::shorten_string, }; use hyper::header; use thiserror::Error; use tracing::info; -use crate::{router::consts::ARCH_ROUTER_V1_SYSTEM_PROMPT, types::types::LlmRouterResponse}; +use super::router_model::RouterModel; -// Domain Service example pub struct RouterService { - providers: Vec, - providers_with_usage: Vec, router_url: String, client: reqwest::Client, - llm_providers_with_usage_yaml: String, + router_model: Arc, + routing_model_name: String, } #[derive(Debug, Error)] pub enum RoutingError { #[error("Failed to send request: {0}")] RequestError(#[from] reqwest::Error), + #[error("Failed to parse JSON: {0}")] JsonError(#[from] serde_json::Error), + + #[error("Router model error: {0}")] + RouterModelError(#[from] super::router_model::RoutingModelError), } -type Result = std::result::Result; +pub type Result = std::result::Result; impl RouterService { - pub fn new(providers: Vec, router_url: String) -> Self { + pub fn new( + providers: Vec, + router_url: String, + routing_model_name: String, + ) -> Self { let providers_with_usage = providers .iter() .filter(|provider| provider.usage.is_some()) @@ -51,74 +60,54 @@ impl RouterService { info!( "llm_providers from config with usage: {}...", - &llm_providers_with_usage_yaml.replace("\n", "\\n") + shorten_string(&llm_providers_with_usage_yaml.replace("\n", "\\n")) ); + let router_model = Arc::new(super::router_model_v1::RouterModelV1::new( + llm_providers_with_usage_yaml.clone(), + routing_model_name.clone(), + )); + RouterService { - providers, - providers_with_usage, router_url, - llm_providers_with_usage_yaml, client: reqwest::Client::new(), + router_model, + routing_model_name, } } pub async fn determine_route( &self, - chat_completion_request: &ChatCompletionsRequest, - ) -> Result { - let message = ARCH_ROUTER_V1_SYSTEM_PROMPT - .replace("{routes}", &self.llm_providers_with_usage_yaml) - .replace( - "{conversation}", - &serde_json::to_string_pretty(&chat_completion_request.messages).unwrap(), - ); - - let router_request: ChatCompletionsRequest = ChatCompletionsRequest { - model: "cotran2/llama-1b-4-26".to_string(), - messages: vec![Message { - content: Some(message), - role: USER_ROLE.to_string(), - model: None, - tool_calls: None, - tool_call_id: None, - }], - tools: None, - stream: false, - stream_options: None, - metadata: None, - }; + messages: &[Message], + trace_parent: Option, + ) -> Result> { + let router_request = self.router_model.generate_request(messages); info!( "router_request: {}", - &serde_json::to_string(&router_request).unwrap() + shorten_string(&serde_json::to_string(&router_request).unwrap()) ); - // let trace_parent = request_headers - // .iter() - // .find(|(ty, _)| ty.as_str() == "traceparent") - // .map(|(_, value)| value.to_str().unwrap_or_default()); - let mut llm_route_request_headers = header::HeaderMap::new(); llm_route_request_headers.insert( header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"), ); - // // attach traceparent header to the llm router request - // if let Some(trace_parent) = trace_parent { - // llm_route_request_headers.insert( - // header::HeaderName::from_static("traceparent"), - // header::HeaderValue::from_str(trace_parent).unwrap(), - // ); - // } - llm_route_request_headers.insert( - header::HeaderName::from_static("host"), - header::HeaderValue::from_static("router_model_host"), + header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), + header::HeaderValue::from_str(&self.routing_model_name).unwrap(), ); - let res = reqwest::Client::new() + if let Some(trace_parent) = trace_parent { + llm_route_request_headers.insert( + header::HeaderName::from_static("traceparent"), + header::HeaderValue::from_str(&trace_parent).unwrap(), + ); + } + + let res = self + .client .post(&self.router_url) .headers(llm_route_request_headers) .body(serde_json::to_string(&router_request).unwrap()) @@ -129,36 +118,14 @@ impl RouterService { let chat_completion_response: ChatCompletionsResponse = serde_json::from_str(&body)?; - info!( - "chat_completion_response: {}", - &serde_json::to_string(&chat_completion_response).unwrap() - ); + let selected_llm = self.router_model.parse_response( + chat_completion_response.choices[0] + .message + .content + .as_ref() + .unwrap(), + )?; - let router_resp = chat_completion_response.choices[0] - .message - .content - .as_ref() - .unwrap(); - let router_resp_fixed = router_resp.replace("'", "\""); - let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?; - - info!( - "router_response json: {}", - serde_json::to_string(&router_response).unwrap() - ); - - let selecter_llm = router_response - .route - .map(|route| route.strip_suffix("()").unwrap_or_default().to_string()) - .unwrap_or_default(); - - if selecter_llm.is_empty() { - let conversation = &serde_json::to_string(&chat_completion_request.messages).unwrap(); - info!("no route selected for conversation: {}", conversation); - } - - info!("selecter_llm: {}", selecter_llm); - - Ok(self.router_url.clone()) + Ok(selected_llm) } } diff --git a/crates/whitestaff/src/router/mod.rs b/crates/whitestaff/src/router/mod.rs index c9757892..e35ea731 100644 --- a/crates/whitestaff/src/router/mod.rs +++ b/crates/whitestaff/src/router/mod.rs @@ -1,2 +1,3 @@ pub mod llm_router; -mod consts; +pub mod router_model; +pub mod router_model_v1; diff --git a/crates/whitestaff/src/router/router_model.rs b/crates/whitestaff/src/router/router_model.rs new file mode 100644 index 00000000..e9f5e256 --- /dev/null +++ b/crates/whitestaff/src/router/router_model.rs @@ -0,0 +1,15 @@ +use common::api::open_ai::{ChatCompletionsRequest, Message}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum RoutingModelError { + #[error("Failed to parse JSON: {0}")] + JsonError(#[from] serde_json::Error), +} + +pub type Result = std::result::Result; + +pub trait RouterModel: Send + Sync { + fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest; + fn parse_response(&self, content: &str) -> Result>; +} diff --git a/crates/whitestaff/src/router/router_model_v1.rs b/crates/whitestaff/src/router/router_model_v1.rs new file mode 100644 index 00000000..6d1cb7fb --- /dev/null +++ b/crates/whitestaff/src/router/router_model_v1.rs @@ -0,0 +1,140 @@ +use common::{ + api::open_ai::{ChatCompletionsRequest, Message}, + consts::USER_ROLE, +}; +use serde::{Deserialize, Serialize}; +use tracing::info; + +use super::router_model::{RouterModel, RoutingModelError}; + +pub const ARCH_ROUTER_V1_SYSTEM_PROMPT: &str = r#" +You are an advanced Routing Assistant designed to select the optimal route based on user requests. +Your task is to analyze conversations and match them to the most appropriate predefined route. +Review the available routes config: + +# ROUTES CONFIG START +{routes} +# ROUTES CONFIG END + +Examine the following conversation between a user and an assistant: + +# CONVERSATION START +{conversation} +# CONVERSATION END + +Your goal is to identify the most appropriate route that matches the user's LATEST intent. Follow these steps: + +1. Carefully read and analyze the provided conversation, focusing on the user's latest request and the conversation scenario. +2. Check if the user's request and scenario matches any of the routes in the routing configuration (focus on the description). +3. Find the route that best matches. +4. Use context clues from the entire conversation to determine the best fit. +5. Return the best match possible. You only response the name of the route that best matches the user's request, use the exact name in the routes config. +6. If no route relatively close to matches the user's latest intent or user last message is thank you or greeting, return an empty route ''. + +# OUTPUT FORMAT +Your final output must follow this JSON format: +{ + "route": "route_name" # The matched route name, or empty string '' if no match +} + +Based on your analysis, provide only the JSON object as your final output with no additional text, explanations, or whitespace. +"#; + +pub type Result = std::result::Result; + +pub struct RouterModelV1 { + llm_providers_with_usage_yaml: String, + routing_model: String, +} + +impl RouterModelV1 { + pub fn new(llm_providers_with_usage_yaml: String, routing_model: String) -> Self { + RouterModelV1 { + llm_providers_with_usage_yaml, + routing_model, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct LlmRouterResponse { + pub route: Option, +} + +impl RouterModel for RouterModelV1 { + fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest { + let message = ARCH_ROUTER_V1_SYSTEM_PROMPT + .replace("{routes}", &self.llm_providers_with_usage_yaml) + .replace( + "{conversation}", + &serde_json::to_string_pretty(messages).unwrap(), + ); + + ChatCompletionsRequest { + model: self.routing_model.clone(), + messages: vec![Message { + content: Some(message), + role: USER_ROLE.to_string(), + model: None, + tool_calls: None, + tool_call_id: None, + }], + tools: None, + stream: false, + stream_options: None, + metadata: None, + } + } + + fn parse_response(&self, content: &str) -> Result> { + let router_resp_fixed = fix_json_response(content); + info!( + "router response (fixed): {}", + router_resp_fixed.replace("\n", "\\n") + ); + let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?; + + let selecter_llm = router_response + .route + .map(|route| route.strip_suffix("()").unwrap_or_default().to_string()) + .unwrap(); + + if selecter_llm.is_empty() { + return Ok(None); + } + + Ok(Some(selecter_llm)) + } +} + +fn fix_json_response(body: &str) -> String { + let mut updated_body = body.to_string(); + + updated_body = updated_body.replace("'", "\""); + + if updated_body.contains("\\n") { + updated_body = updated_body.replace("\\n", ""); + } + + if updated_body.starts_with("```json") { + updated_body = updated_body + .strip_prefix("```json") + .unwrap_or(&updated_body) + .to_string(); + } + + if updated_body.ends_with("```") { + updated_body = updated_body + .strip_suffix("```") + .unwrap_or(&updated_body) + .to_string(); + } + + updated_body +} + +impl std::fmt::Debug for dyn RouterModel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "RouterModel") + } +} diff --git a/crates/whitestaff/src/types/mod.rs b/crates/whitestaff/src/types/mod.rs deleted file mode 100644 index cd408564..00000000 --- a/crates/whitestaff/src/types/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod types; diff --git a/crates/whitestaff/src/types/types.rs b/crates/whitestaff/src/types/types.rs deleted file mode 100644 index 0e929413..00000000 --- a/crates/whitestaff/src/types/types.rs +++ /dev/null @@ -1,6 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LlmRouterResponse { - pub route: Option, -} diff --git a/demos/use_cases/preference_based_routing/arch_config.yaml b/demos/use_cases/preference_based_routing/arch_config.yaml index 9ca5a35a..d31c56dd 100644 --- a/demos/use_cases/preference_based_routing/arch_config.yaml +++ b/demos/use_cases/preference_based_routing/arch_config.yaml @@ -1,10 +1,7 @@ version: "0.1-beta" -endpoints: - gcp_hosted_outer_llm: - endpoint: 34.46.85.85:8000 - http_host: 34.46.85.85 - # endpoint: host.docker.internal:11223 +routing: + model: gpt-4o listeners: egress_traffic: