initial commit

This commit is contained in:
Adil Hafeez 2025-05-08 14:07:23 -07:00
parent 1f95fac4af
commit 1d19f0c2f7
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
36 changed files with 3003 additions and 109 deletions

1642
crates/Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,3 +1,3 @@
[workspace]
resolver = "2"
members = ["llm_gateway", "prompt_gateway", "common"]
members = ["llm_gateway", "prompt_gateway", "common", "whitestaff"]

View file

@ -166,6 +166,7 @@ pub struct LlmProvider {
pub endpoint: Option<String>,
pub port: Option<u16>,
pub rate_limits: Option<LlmRatelimit>,
pub usage: Option<String>,
}
impl Display for LlmProvider {

View file

@ -27,3 +27,4 @@ pub const HALLUCINATION_TEMPLATE: &str =
"It seems I'm missing some information. Could you provide the following details ";
pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http";
pub const OTEL_POST_PATH: &str = "/v1/traces";
pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route";

View file

@ -2,6 +2,7 @@ use std::rc::Rc;
use crate::{configuration, llm_providers::LlmProviders};
use configuration::LlmProvider;
use log::info;
use rand::{seq::IteratorRandom, thread_rng};
#[derive(Debug)]
@ -29,6 +30,8 @@ pub fn get_llm_provider(
ProviderHint::Name(name) => llm_providers.get(&name),
});
info!("selected provider: maybe_provider: {:?}", maybe_provider);
if let Some(provider) = maybe_provider {
return provider;
}

View file

@ -3,6 +3,8 @@ use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod filter_context;
mod llm_routing;
mod llm_routing_consts;
mod metrics;
mod stream_context;

View file

@ -0,0 +1,106 @@
// use std::rc::Rc;
// use std::time::Duration;
// use common::api::open_ai::{ChatCompletionsRequest, Message};
// use common::configuration::LlmProvider;
// use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, ARCH_UPSTREAM_HOST_HEADER};
// use common::errors::ServerError;
// use common::http::{CallArgs, Client};
// use log::{info, warn};
// use proxy_wasm::traits::HttpContext;
// use proxy_wasm::types::Action;
// use crate::llm_routing_consts::SYSTEM_PROMPT;
// use crate::stream_context::{CallContext, StreamContext};
// pub trait Routing {
// fn route(&self) -> Action;
// }
// impl Routing for StreamContext {
// fn route(&self) -> Action {
// let usage_based_providers = self
// .llm_providers
// .iter()
// .filter(|(_, provider)| provider.usage.is_some())
// .map(|(_, provider)| provider.clone())
// .collect::<Vec<Rc<LlmProvider>>>();
// info!(
// "usage based providers found: {}",
// usage_based_providers
// .iter()
// .map(|provider| provider.name.clone())
// .collect::<Vec<String>>()
// .join(", ")
// );
// if usage_based_providers.is_empty() {
// self.set_http_request_body(
// 0,
// self.request_size.unwrap(),
// self.request_body.as_ref().unwrap().as_bytes(),
// );
// return Action::Continue;
// }
// let llm_routes_str = r#"- name: gpt-4o
// description: simple requests, basic fact retrieval, easy to answer
// - name: o4-mini()
// description: complex reasoning problem, require multi step answer"#;
// let chat_completions_request_messages_str =
// serde_json::to_string(&self.chat_completion_request.as_ref().unwrap().messages)
// .expect("failed to serialize llm routing request messages");
// let system_prompt_formatted = SYSTEM_PROMPT
// .replace("{routes}", llm_routes_str)
// .replace("{conversation}", &chat_completions_request_messages_str);
// let message = Message {
// role: "user".to_string(),
// content: Some(system_prompt_formatted),
// model: None,
// tool_calls: None,
// tool_call_id: None,
// };
// let llm_routing_request = ChatCompletionsRequest {
// model: "cotran2/llama-1b-4-26".to_string(),
// messages: vec![message],
// tools: None,
// stream: false,
// stream_options: None,
// metadata: None,
// };
// let llm_routing_request_str = serde_json::to_string(&llm_routing_request)
// .expect("failed to serialize llm routing request");
// let headers = vec![
// (":method", "POST"),
// (ARCH_UPSTREAM_HOST_HEADER, "gcp_hosted_outer_llm"),
// (":path", "/v1/chat/completions"),
// (":authority", "gcp_hosted_outer_llm"),
// ("content-type", "application/json"),
// ("x-envoy-max-retries", "3"),
// ("x-envoy-upstream-rq-timeout-ms", "5000"),
// ];
// let call_args = CallArgs::new(
// ARCH_INTERNAL_CLUSTER_NAME,
// "/v1/chat/completions",
// headers,
// llm_routing_request_str.as_bytes().into(),
// vec![],
// Duration::from_secs(5),
// );
// if let Err(e) = self.http_call(call_args, CallContext {}) {
// warn!("failed to call LLM provider: {}", e);
// self.send_server_error(ServerError::HttpDispatch(e), None);
// }
// Action::Pause
// }
// }

View file

@ -0,0 +1,32 @@
// pub const SYSTEM_PROMPT: &str = r#"
// You are an advanced Routing Assistant designed to select the optimal route based on user requests.
// Your task is to analyze conversations and match them to the most appropriate predefined route.
// Review the available routes config:
// # ROUTES CONFIG START
// {routes}
// # ROUTES CONFIG END
// Examine the following conversation between a user and an assistant:
// # CONVERSATION START
// {conversation}
// # CONVERSATION END
// Your goal is to identify the most appropriate route that matches the user's LATEST intent. Follow these steps:
// 1. Carefully read and analyze the provided conversation, focusing on the user's latest request and the conversation scenario.
// 2. Check if the user's request and scenario matches any of the routes in the routing configuration (focus on the description).
// 3. Find the route that best matches.
// 4. Use context clues from the entire conversation to determine the best fit.
// 5. Return the best match possible. You only response the name of the route that best matches the user's request, use the exact name in the routes config.
// 6. If no route relatively close to matches the user's latest intent or user last message is thank you or greeting, return an empty route ''.
// # OUTPUT FORMAT
// Your final output must follow this JSON format:
// {
// "route": "route_name" # The matched route name, or empty string '' if no match
// }
// Based on your analysis, provide only the JSON object as your final output with no additional text, explanations, or whitespace.
// "#;

View file

@ -9,9 +9,10 @@ use common::consts::{
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
use common::errors::ServerError;
use common::http::Client;
use common::llm_providers::LlmProviders;
use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::stats::{Gauge, IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use http::StatusCode;
@ -19,12 +20,16 @@ use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::collections::VecDeque;
use std::cell::RefCell;
use std::collections::{HashMap, VecDeque};
use std::num::NonZero;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
#[derive(Debug)]
pub struct CallContext {}
pub struct StreamContext {
context_id: u32,
metrics: Rc<Metrics>,
@ -32,7 +37,7 @@ pub struct StreamContext {
streaming_response: bool,
response_tokens: usize,
is_chat_completions_request: bool,
llm_providers: Rc<LlmProviders>,
pub(crate) llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
request_id: Option<String>,
start_time: SystemTime,
@ -43,6 +48,10 @@ pub struct StreamContext {
user_message: Option<Message>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
overrides: Rc<Option<Overrides>>,
pub(crate) request_body: Option<String>,
pub(crate) request_size: Option<usize>,
pub(crate) chat_completion_request: Option<ChatCompletionsRequest>,
callouts: RefCell<HashMap<u32, CallContext>>,
}
impl StreamContext {
@ -71,8 +80,13 @@ impl StreamContext {
user_message: None,
traces_queue,
request_body_sent_time: None,
request_body: None,
request_size: None,
chat_completion_request: None,
callouts: RefCell::new(HashMap::new()),
}
}
fn llm_provider(&self) -> &LlmProvider {
self.llm_provider
.as_ref()
@ -156,7 +170,7 @@ impl StreamContext {
});
}
fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
pub fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
warn!("server error occurred: {}", error);
self.send_http_response(
override_status_code
@ -228,6 +242,7 @@ impl HttpContext for StreamContext {
stream: None,
port: None,
rate_limits: None,
usage: None,
}));
} else {
self.select_llm_provider();
@ -321,7 +336,7 @@ impl HttpContext for StreamContext {
// deserialized_body.metadata = None;
// delete model key from message array
for message in deserialized_body.messages.iter_mut() {
message.model = None;
// message.model = None;
}
self.user_message = deserialized_body
@ -331,43 +346,45 @@ impl HttpContext for StreamContext {
.last()
.cloned();
let model_name = match self.llm_provider.as_ref() {
Some(llm_provider) => llm_provider.model.as_ref(),
None => None,
};
// let model_name = match self.llm_provider.as_ref() {
// Some(llm_provider) => llm_provider.model.as_ref(),
// None => None,
// };
let use_agent_orchestrator = match self.overrides.as_ref() {
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
None => false,
};
// let use_agent_orchestrator = match self.overrides.as_ref() {
// Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
// None => false,
// };
let model_requested = deserialized_body.model.clone();
if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
deserialized_body.model = match model_name {
Some(model_name) => model_name.clone(),
None => {
if use_agent_orchestrator {
"agent_orchestrator".to_string()
} else {
self.send_server_error(
ServerError::BadRequest {
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
},
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
}
}
}
// if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
// deserialized_body.model = match model_name {
// Some(model_name) => model_name.clone(),
// None => {
// if use_agent_orchestrator {
// "agent_orchestrator".to_string()
// } else {
// self.send_server_error(
// ServerError::BadRequest {
// why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
// },
// Some(StatusCode::BAD_REQUEST),
// );
// return Action::Continue;
// }
// }
// }
// }
info!(
"on_http_request_body: provider: {}, model requested: {}, model selected: {}",
"on_http_request_body: provider: {}, model requested: {}, model selected: {:?}",
self.llm_provider().name,
model_requested,
model_name.unwrap_or(&"None".to_string()),
self.llm_provider().model,
);
deserialized_body.model = self.llm_provider().model.clone().unwrap();
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
debug!(
@ -404,7 +421,12 @@ impl HttpContext for StreamContext {
self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
Action::Continue
self.chat_completion_request = Some(deserialized_body);
self.request_body = Some(chat_completion_request_str);
self.request_size = Some(body_size);
return Action::Continue;
// return self.route();
}
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
@ -665,4 +687,50 @@ fn current_time_ns() -> u128 {
.as_nanos()
}
impl Context for StreamContext {}
impl Context for StreamContext {
fn on_http_call_response(
&mut self,
token_id: u32,
_num_headers: usize,
body_size: usize,
_num_trailers: usize,
) {
debug!(
"on_http_call_response [S={}] token_id={} num_headers={} body_size={} num_trailers={}",
self.context_id, token_id, _num_headers, body_size, _num_trailers
);
let _callout_data = self
.callouts
.borrow_mut()
.remove(&token_id)
.expect("invalid token_id");
let body = self
.get_http_call_response_body(0, body_size)
.unwrap_or_default();
info!(
"on_http_call_response: response body: {}",
String::from_utf8_lossy(&body)
);
self.set_http_request_body(
0,
self.request_size.unwrap(),
self.request_body.as_ref().unwrap().as_bytes(),
);
}
}
impl Client for StreamContext {
type CallContext = CallContext;
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
&self.callouts
}
fn active_http_calls(&self) -> &Gauge {
&self.metrics.active_http_calls
}
}

View file

@ -103,7 +103,7 @@ impl StreamContext {
}
}
pub fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
pub (crate) fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
self.send_http_response(
override_status_code
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)

View file

@ -0,0 +1,29 @@
[package]
name = "whitestaff"
version = "0.1.0"
edition = "2021"
[dependencies]
bytes = "1.10.1"
common = { version = "0.1.0", path = "../common" }
eventsource-client = "0.15.0"
eventsource-stream = "0.2.3"
futures = "0.3.31"
http-body-util = "0.1.3"
hyper = { version="1.6.0", features = ["full"] }
hyper-util = "0.1.11"
opentelemetry = "0.29.1"
opentelemetry-http = "0.29.0"
opentelemetry-otlp = "0.29.0"
opentelemetry-stdout = "0.29.0"
opentelemetry_sdk = "0.29.0"
reqwest = "0.12.15"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
serde_yaml = "0.9.34"
thiserror = "2.0.12"
tokio = { version = "1.44.2", features = ["full"] }
tokio-stream = "0.1.17"
tracing = "0.1.41"
tracing-opentelemetry = "0.30.0"
tracing-subscriber = { version="0.3.19", features = ["env-filter", "fmt"] }

View file

@ -0,0 +1,32 @@
pub const SYSTEM_PROMPT_Z: &str = r#"
You are an advanced Routing Assistant designed to select the optimal route based on user requests.
Your task is to analyze conversations and match them to the most appropriate predefined route.
Review the available routes config:
# ROUTES CONFIG START
{routes}
# ROUTES CONFIG END
Examine the following conversation between a user and an assistant:
# CONVERSATION START
{conversation}
# CONVERSATION END
Your goal is to identify the most appropriate route that matches the user's LATEST intent. Follow these steps:
1. Carefully read and analyze the provided conversation, focusing on the user's latest request and the conversation scenario.
2. Check if the user's request and scenario matches any of the routes in the routing configuration (focus on the description).
3. Find the route that best matches.
4. Use context clues from the entire conversation to determine the best fit.
5. Return the best match possible. You only response the name of the route that best matches the user's request, use the exact name in the routes config.
6. If no route relatively close to matches the user's latest intent or user last message is thank you or greeting, return an empty route ''.
# OUTPUT FORMAT
Your final output must follow this JSON format:
{
"route": "route_name" # The matched route name, or empty string '' if no match
}
Based on your analysis, provide only the JSON object as your final output with no additional text, explanations, or whitespace.
"#;

View file

@ -0,0 +1,3 @@
mod consts2;
mod router;
mod types;

View file

@ -0,0 +1,406 @@
use bytes::Bytes;
use common::api::open_ai::{ChatCompletionsRequest, ChatCompletionsResponse, Message};
use common::configuration::{Configuration, LlmProvider};
use common::consts::{ARCH_PROVIDER_HINT_HEADER, USER_ROLE};
use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full};
use hyper::body::{Body, Incoming};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{header, Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use opentelemetry::global::BoxedTracer;
use opentelemetry::trace::FutureExt;
use opentelemetry::{
global,
trace::{SpanKind, Tracer},
Context,
};
use opentelemetry_http::HeaderExtractor;
use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider};
use opentelemetry_stdout::SpanExporter;
use types::types::LlmRouterResponse;
use std::env;
use std::sync::{Arc, OnceLock};
use tokio::net::TcpListener;
use tracing::info;
use tracing_subscriber::EnvFilter;
mod consts2;
use consts2::SYSTEM_PROMPT_Z;
mod types;
const BIND_ADDRESS: &str = "0.0.0.0:9091";
fn get_tracer() -> &'static BoxedTracer {
static TRACER: OnceLock<BoxedTracer> = OnceLock::new();
TRACER.get_or_init(|| global::tracer("archgw/whitestaff"))
}
// Utility function to extract the context from the incoming request headers
fn extract_context_from_request(req: &Request<Incoming>) -> Context {
global::get_text_map_propagator(|propagator| {
propagator.extract(&HeaderExtractor(req.headers()))
})
}
fn init_tracer() -> SdkTracerProvider {
global::set_text_map_propagator(TraceContextPropagator::new());
// Install stdout exporter pipeline to be able to retrieve the collected spans.
// For the demonstration, use `Sampler::AlwaysOn` sampler to sample all traces.
let provider = SdkTracerProvider::builder()
.with_simple_exporter(SpanExporter::default())
.build();
global::set_tracer_provider(provider.clone());
provider
}
fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new()
.map_err(|never| match never {})
.boxed()
}
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}
fn shorten_string(s: &str) -> String {
if s.len() > 80 {
format!("{}...", &s[..80])
} else {
s.to_string()
}
}
async fn chat_completion(
req: Request<hyper::body::Incoming>,
arch_config: Arc<Configuration>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let max = req.body().size_hint().upper().unwrap_or(u64::MAX);
if max > 1024 * 1024 {
let error_msg = format!("Request body too large: {} bytes", max);
let mut too_large = Response::new(full(error_msg));
*too_large.status_mut() = StatusCode::PAYLOAD_TOO_LARGE;
return Ok(too_large);
}
let mut request_headers = req.headers().clone();
info!(
"Request headers: {}",
request_headers
.iter()
.map(|(k, v)| format!("{}: {}", k, v.to_str().unwrap_or_default()))
.collect::<Vec<String>>()
.join(", ")
);
let chat_request_bytes = req.collect().await?.to_bytes();
let chat_completion_request: ChatCompletionsRequest =
match serde_json::from_slice(&chat_request_bytes) {
Ok(request) => request,
Err(err) => {
let err_msg = format!("Failed to parse request body: {}", err);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
};
info!(
"Received request: {}",
&serde_json::to_string(&chat_completion_request).unwrap()
);
let llm_providers: Vec<LlmProvider> = chat_completion_request
.metadata
.as_ref()
.and_then(|metadata| metadata.get("llm_providers"))
.and_then(|providers| serde_json::from_str::<Vec<LlmProvider>>(providers).ok())
.unwrap_or_default();
info!(
"llm_providers from request: {}...",
shorten_string(&serde_json::to_string(&llm_providers).unwrap())
);
let llm_router_with_usage = arch_config
.llm_providers
.iter()
.filter(|provider| provider.usage.is_some()).cloned()
.collect::<Vec<LlmProvider>>();
// convert the llm_providers to yaml string but only include name and usage
let llm_providers_yaml = llm_router_with_usage
.iter()
.map(|provider| {
format!(
"- name: {}()\n description: {}",
provider.name,
provider.usage.as_ref().unwrap_or(&"".to_string())
)
})
.collect::<Vec<String>>()
.join("\n");
info!(
"llm_providers from config: {}...",
shorten_string(&llm_providers_yaml.replace("\n", "\\n"))
);
let message = SYSTEM_PROMPT_Z
.replace("{routes}", &llm_providers_yaml)
.replace(
"{conversation}",
&serde_json::to_string_pretty(&chat_completion_request.messages).unwrap(),
);
let router_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: "cotran2/llama-1b-4-26".to_string(),
messages: vec![Message {
content: Some(message),
role: USER_ROLE.to_string(),
model: None,
tool_calls: None,
tool_call_id: None,
}],
tools: None,
stream: false,
stream_options: None,
metadata: None,
};
info!(
"router_request: {}...",
shorten_string(&serde_json::to_string(&router_request).unwrap())
);
let trace_parent = request_headers
.iter()
.find(|(ty, _)| ty.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default());
let mut llm_route_request_headers = header::HeaderMap::new();
llm_route_request_headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
// attach traceparent header to the llm router request
if let Some(trace_parent) = trace_parent {
llm_route_request_headers.insert(
header::HeaderName::from_static("traceparent"),
header::HeaderValue::from_str(trace_parent).unwrap(),
);
}
llm_route_request_headers.insert(
header::HeaderName::from_static("host"),
header::HeaderValue::from_static("router_model_host"),
);
let res = match reqwest::Client::new()
.post("http://localhost:9090/v1/chat/completions")
.headers(llm_route_request_headers)
.body(serde_json::to_string(&router_request).unwrap())
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let body = match res.text().await {
Ok(body) => body,
Err(err) => {
let err_msg = format!("Failed to read response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
Ok(response) => response,
Err(err) => {
let err_msg = format!("Failed to parse response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
info!(
"chat_completion_response: {}",
shorten_string(&serde_json::to_string(&chat_completion_response).unwrap())
);
let router_resp = chat_completion_response.choices[0]
.message
.content
.as_ref()
.unwrap();
let router_resp_fixed = router_resp.replace("'", "\"");
let router_response: LlmRouterResponse = match serde_json::from_str(router_resp_fixed.as_str())
{
Ok(response) => response,
Err(err) => {
let err_msg = format!("Failed to parse response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
info!(
"router_response json: {}",
serde_json::to_string(&router_response).unwrap()
);
let selecter_llm = router_response
.route
.map(|route| route.strip_suffix("()").unwrap_or_default().to_string())
.unwrap_or_default();
if selecter_llm.is_empty() {
let conversation = &serde_json::to_string(&chat_completion_request.messages).unwrap();
info!(
"no route selected for conversation: {}",
shorten_string(conversation)
);
}
info!("selecter_llm: {}", selecter_llm);
if let Some(trace_parent) = trace_parent {
request_headers.insert(
header::HeaderName::from_static("traceparent"),
header::HeaderValue::from_str(trace_parent).unwrap(),
);
}
if !selecter_llm.is_empty() {
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&selecter_llm).unwrap(),
);
}
let llm_response = match reqwest::Client::new()
.post("http://localhost:12000/v1/chat/completions")
.headers(request_headers)
.body(chat_request_bytes)
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let body = match llm_response.text().await {
Ok(body) => body,
Err(err) => {
let err_msg = format!("Failed to read response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let mut ok_response = Response::new(full(body));
*ok_response.status_mut() = StatusCode::OK;
Ok(ok_response)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let _tracer_provider = init_tracer();
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
)
.init();
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
//loading arch_config.yaml file
let arch_config_path =
env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "arch_config.yaml".to_string());
info!("Loading arch_config.yaml from {}", arch_config_path);
let arch_config =
std::fs::read_to_string(&arch_config_path).expect("Failed to read arch_config.yaml");
let config: Configuration =
serde_yaml::from_str(&arch_config).expect("Failed to parse arch_config.yaml");
let arch_config = Arc::new(config);
info!(
"arch_config: {:?}",
shorten_string(&serde_json::to_string(arch_config.as_ref()).unwrap())
);
info!("Listening on http://{}", bind_address);
let listener = TcpListener::bind(bind_address).await?;
loop {
let (stream, _) = listener.accept().await?;
let peer_addr = stream.peer_addr()?;
let io = TokioIo::new(stream);
let arch_config = Arc::clone(&arch_config);
let service = service_fn(move |req| {
let arch_config = Arc::clone(&arch_config);
let parent_cx = extract_context_from_request(&req);
info!("parent_cx: {:?}", parent_cx);
let tracer = get_tracer();
let _span = tracer
.span_builder("chat_completion")
.with_kind(SpanKind::Server)
.start_with_context(tracer, &parent_cx);
async move {
match (req.method(), req.uri().path()) {
(&Method::POST, "/v1/chat/completions") => {
info!(
"config: {:?}",
shorten_string(
&serde_json::to_string(&arch_config.llm_providers).unwrap()
)
);
chat_completion(req, arch_config)
.with_context(parent_cx)
.await
}
_ => {
let mut not_found = Response::new(empty());
*not_found.status_mut() = StatusCode::NOT_FOUND;
Ok(not_found)
}
}
}
});
tokio::task::spawn(async move {
info!("Accepted connection from {:?}", peer_addr);
if let Err(err) = http1::Builder::new()
// .serve_connection(io, service_fn(chat_completion))
.serve_connection(io, service)
.await
{
info!("Error serving connection: {:?}", err);
}
});
}
}

View file

@ -0,0 +1,32 @@
pub const ARCH_ROUTER_V1_SYSTEM_PROMPT: &str = r#"
You are an advanced Routing Assistant designed to select the optimal route based on user requests.
Your task is to analyze conversations and match them to the most appropriate predefined route.
Review the available routes config:
# ROUTES CONFIG START
{routes}
# ROUTES CONFIG END
Examine the following conversation between a user and an assistant:
# CONVERSATION START
{conversation}
# CONVERSATION END
Your goal is to identify the most appropriate route that matches the user's LATEST intent. Follow these steps:
1. Carefully read and analyze the provided conversation, focusing on the user's latest request and the conversation scenario.
2. Check if the user's request and scenario matches any of the routes in the routing configuration (focus on the description).
3. Find the route that best matches.
4. Use context clues from the entire conversation to determine the best fit.
5. Return the best match possible. You only response the name of the route that best matches the user's request, use the exact name in the routes config.
6. If no route relatively close to matches the user's latest intent or user last message is thank you or greeting, return an empty route ''.
# OUTPUT FORMAT
Your final output must follow this JSON format:
{
"route": "route_name" # The matched route name, or empty string '' if no match
}
Based on your analysis, provide only the JSON object as your final output with no additional text, explanations, or whitespace.
"#;

View file

@ -0,0 +1,164 @@
use common::{
api::open_ai::{ChatCompletionsRequest, ChatCompletionsResponse, Message},
configuration::LlmProvider,
consts::USER_ROLE,
};
use hyper::header;
use thiserror::Error;
use tracing::info;
use crate::{router::consts::ARCH_ROUTER_V1_SYSTEM_PROMPT, types::types::LlmRouterResponse};
// Domain Service example
pub struct RouterService {
providers: Vec<LlmProvider>,
providers_with_usage: Vec<LlmProvider>,
router_url: String,
client: reqwest::Client,
llm_providers_with_usage_yaml: String,
}
#[derive(Debug, Error)]
pub enum RoutingError {
#[error("Failed to send request: {0}")]
RequestError(#[from] reqwest::Error),
#[error("Failed to parse JSON: {0}")]
JsonError(#[from] serde_json::Error),
}
type Result<T> = std::result::Result<T, RoutingError>;
impl RouterService {
pub fn new(providers: Vec<LlmProvider>, router_url: String) -> Self {
let providers_with_usage = providers
.iter()
.filter(|provider| provider.usage.is_some())
.cloned()
.collect::<Vec<LlmProvider>>();
// convert the llm_providers to yaml string but only include name and usage
let llm_providers_with_usage_yaml = providers_with_usage
.iter()
.map(|provider| {
format!(
"- name: {}()\n description: {}",
provider.name,
provider.usage.as_ref().unwrap_or(&"".to_string())
)
})
.collect::<Vec<String>>()
.join("\n");
info!(
"llm_providers from config with usage: {}...",
&llm_providers_with_usage_yaml.replace("\n", "\\n")
);
RouterService {
providers,
providers_with_usage,
router_url,
llm_providers_with_usage_yaml,
client: reqwest::Client::new(),
}
}
pub async fn determine_route(
&self,
chat_completion_request: &ChatCompletionsRequest,
) -> Result<String> {
let message = ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", &self.llm_providers_with_usage_yaml)
.replace(
"{conversation}",
&serde_json::to_string_pretty(&chat_completion_request.messages).unwrap(),
);
let router_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: "cotran2/llama-1b-4-26".to_string(),
messages: vec![Message {
content: Some(message),
role: USER_ROLE.to_string(),
model: None,
tool_calls: None,
tool_call_id: None,
}],
tools: None,
stream: false,
stream_options: None,
metadata: None,
};
info!(
"router_request: {}",
&serde_json::to_string(&router_request).unwrap()
);
// let trace_parent = request_headers
// .iter()
// .find(|(ty, _)| ty.as_str() == "traceparent")
// .map(|(_, value)| value.to_str().unwrap_or_default());
let mut llm_route_request_headers = header::HeaderMap::new();
llm_route_request_headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
// // attach traceparent header to the llm router request
// if let Some(trace_parent) = trace_parent {
// llm_route_request_headers.insert(
// header::HeaderName::from_static("traceparent"),
// header::HeaderValue::from_str(trace_parent).unwrap(),
// );
// }
llm_route_request_headers.insert(
header::HeaderName::from_static("host"),
header::HeaderValue::from_static("router_model_host"),
);
let res = reqwest::Client::new()
.post(&self.router_url)
.headers(llm_route_request_headers)
.body(serde_json::to_string(&router_request).unwrap())
.send()
.await?;
let body = res.text().await?;
let chat_completion_response: ChatCompletionsResponse = serde_json::from_str(&body)?;
info!(
"chat_completion_response: {}",
&serde_json::to_string(&chat_completion_response).unwrap()
);
let router_resp = chat_completion_response.choices[0]
.message
.content
.as_ref()
.unwrap();
let router_resp_fixed = router_resp.replace("'", "\"");
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
info!(
"router_response json: {}",
serde_json::to_string(&router_response).unwrap()
);
let selecter_llm = router_response
.route
.map(|route| route.strip_suffix("()").unwrap_or_default().to_string())
.unwrap_or_default();
if selecter_llm.is_empty() {
let conversation = &serde_json::to_string(&chat_completion_request.messages).unwrap();
info!("no route selected for conversation: {}", conversation);
}
info!("selecter_llm: {}", selecter_llm);
Ok(self.router_url.clone())
}
}

View file

@ -0,0 +1,2 @@
pub mod llm_router;
mod consts;

View file

@ -0,0 +1 @@
pub mod types;

View file

@ -0,0 +1,6 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRouterResponse {
pub route: Option<String>,
}