initial commit

This commit is contained in:
Adil Hafeez 2025-05-08 14:07:23 -07:00
parent 1f95fac4af
commit 1d19f0c2f7
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
36 changed files with 3003 additions and 109 deletions

View file

@ -3,6 +3,8 @@ use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod filter_context;
mod llm_routing;
mod llm_routing_consts;
mod metrics;
mod stream_context;

View file

@ -0,0 +1,106 @@
// use std::rc::Rc;
// use std::time::Duration;
// use common::api::open_ai::{ChatCompletionsRequest, Message};
// use common::configuration::LlmProvider;
// use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, ARCH_UPSTREAM_HOST_HEADER};
// use common::errors::ServerError;
// use common::http::{CallArgs, Client};
// use log::{info, warn};
// use proxy_wasm::traits::HttpContext;
// use proxy_wasm::types::Action;
// use crate::llm_routing_consts::SYSTEM_PROMPT;
// use crate::stream_context::{CallContext, StreamContext};
// pub trait Routing {
// fn route(&self) -> Action;
// }
// impl Routing for StreamContext {
// fn route(&self) -> Action {
// let usage_based_providers = self
// .llm_providers
// .iter()
// .filter(|(_, provider)| provider.usage.is_some())
// .map(|(_, provider)| provider.clone())
// .collect::<Vec<Rc<LlmProvider>>>();
// info!(
// "usage based providers found: {}",
// usage_based_providers
// .iter()
// .map(|provider| provider.name.clone())
// .collect::<Vec<String>>()
// .join(", ")
// );
// if usage_based_providers.is_empty() {
// self.set_http_request_body(
// 0,
// self.request_size.unwrap(),
// self.request_body.as_ref().unwrap().as_bytes(),
// );
// return Action::Continue;
// }
// let llm_routes_str = r#"- name: gpt-4o
// description: simple requests, basic fact retrieval, easy to answer
// - name: o4-mini()
// description: complex reasoning problem, require multi step answer"#;
// let chat_completions_request_messages_str =
// serde_json::to_string(&self.chat_completion_request.as_ref().unwrap().messages)
// .expect("failed to serialize llm routing request messages");
// let system_prompt_formatted = SYSTEM_PROMPT
// .replace("{routes}", llm_routes_str)
// .replace("{conversation}", &chat_completions_request_messages_str);
// let message = Message {
// role: "user".to_string(),
// content: Some(system_prompt_formatted),
// model: None,
// tool_calls: None,
// tool_call_id: None,
// };
// let llm_routing_request = ChatCompletionsRequest {
// model: "cotran2/llama-1b-4-26".to_string(),
// messages: vec![message],
// tools: None,
// stream: false,
// stream_options: None,
// metadata: None,
// };
// let llm_routing_request_str = serde_json::to_string(&llm_routing_request)
// .expect("failed to serialize llm routing request");
// let headers = vec![
// (":method", "POST"),
// (ARCH_UPSTREAM_HOST_HEADER, "gcp_hosted_outer_llm"),
// (":path", "/v1/chat/completions"),
// (":authority", "gcp_hosted_outer_llm"),
// ("content-type", "application/json"),
// ("x-envoy-max-retries", "3"),
// ("x-envoy-upstream-rq-timeout-ms", "5000"),
// ];
// let call_args = CallArgs::new(
// ARCH_INTERNAL_CLUSTER_NAME,
// "/v1/chat/completions",
// headers,
// llm_routing_request_str.as_bytes().into(),
// vec![],
// Duration::from_secs(5),
// );
// if let Err(e) = self.http_call(call_args, CallContext {}) {
// warn!("failed to call LLM provider: {}", e);
// self.send_server_error(ServerError::HttpDispatch(e), None);
// }
// Action::Pause
// }
// }

View file

@ -0,0 +1,32 @@
// pub const SYSTEM_PROMPT: &str = r#"
// You are an advanced Routing Assistant designed to select the optimal route based on user requests.
// Your task is to analyze conversations and match them to the most appropriate predefined route.
// Review the available routes config:
// # ROUTES CONFIG START
// {routes}
// # ROUTES CONFIG END
// Examine the following conversation between a user and an assistant:
// # CONVERSATION START
// {conversation}
// # CONVERSATION END
// Your goal is to identify the most appropriate route that matches the user's LATEST intent. Follow these steps:
// 1. Carefully read and analyze the provided conversation, focusing on the user's latest request and the conversation scenario.
// 2. Check if the user's request and scenario matches any of the routes in the routing configuration (focus on the description).
// 3. Find the route that best matches.
// 4. Use context clues from the entire conversation to determine the best fit.
// 5. Return the best match possible. You only response the name of the route that best matches the user's request, use the exact name in the routes config.
// 6. If no route relatively close to matches the user's latest intent or user last message is thank you or greeting, return an empty route ''.
// # OUTPUT FORMAT
// Your final output must follow this JSON format:
// {
// "route": "route_name" # The matched route name, or empty string '' if no match
// }
// Based on your analysis, provide only the JSON object as your final output with no additional text, explanations, or whitespace.
// "#;

View file

@ -9,9 +9,10 @@ use common::consts::{
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
use common::errors::ServerError;
use common::http::Client;
use common::llm_providers::LlmProviders;
use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::stats::{Gauge, IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use http::StatusCode;
@ -19,12 +20,16 @@ use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::collections::VecDeque;
use std::cell::RefCell;
use std::collections::{HashMap, VecDeque};
use std::num::NonZero;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Per-callout context value stored in `StreamContext::callouts`, keyed by the
/// `u32` token id of the callout. It currently carries no fields.
#[derive(Debug)]
pub struct CallContext {}
pub struct StreamContext {
context_id: u32,
metrics: Rc<Metrics>,
@ -32,7 +37,7 @@ pub struct StreamContext {
streaming_response: bool,
response_tokens: usize,
is_chat_completions_request: bool,
llm_providers: Rc<LlmProviders>,
pub(crate) llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
request_id: Option<String>,
start_time: SystemTime,
@ -43,6 +48,10 @@ pub struct StreamContext {
user_message: Option<Message>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
overrides: Rc<Option<Overrides>>,
pub(crate) request_body: Option<String>,
pub(crate) request_size: Option<usize>,
pub(crate) chat_completion_request: Option<ChatCompletionsRequest>,
callouts: RefCell<HashMap<u32, CallContext>>,
}
impl StreamContext {
@ -71,8 +80,13 @@ impl StreamContext {
user_message: None,
traces_queue,
request_body_sent_time: None,
request_body: None,
request_size: None,
chat_completion_request: None,
callouts: RefCell::new(HashMap::new()),
}
}
fn llm_provider(&self) -> &LlmProvider {
self.llm_provider
.as_ref()
@ -156,7 +170,7 @@ impl StreamContext {
});
}
fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
pub fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
warn!("server error occurred: {}", error);
self.send_http_response(
override_status_code
@ -228,6 +242,7 @@ impl HttpContext for StreamContext {
stream: None,
port: None,
rate_limits: None,
usage: None,
}));
} else {
self.select_llm_provider();
@ -321,7 +336,7 @@ impl HttpContext for StreamContext {
// deserialized_body.metadata = None;
// delete model key from message array
for message in deserialized_body.messages.iter_mut() {
message.model = None;
// message.model = None;
}
self.user_message = deserialized_body
@ -331,43 +346,45 @@ impl HttpContext for StreamContext {
.last()
.cloned();
let model_name = match self.llm_provider.as_ref() {
Some(llm_provider) => llm_provider.model.as_ref(),
None => None,
};
// let model_name = match self.llm_provider.as_ref() {
// Some(llm_provider) => llm_provider.model.as_ref(),
// None => None,
// };
let use_agent_orchestrator = match self.overrides.as_ref() {
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
None => false,
};
// let use_agent_orchestrator = match self.overrides.as_ref() {
// Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
// None => false,
// };
let model_requested = deserialized_body.model.clone();
if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
deserialized_body.model = match model_name {
Some(model_name) => model_name.clone(),
None => {
if use_agent_orchestrator {
"agent_orchestrator".to_string()
} else {
self.send_server_error(
ServerError::BadRequest {
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
},
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
}
}
}
// if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
// deserialized_body.model = match model_name {
// Some(model_name) => model_name.clone(),
// None => {
// if use_agent_orchestrator {
// "agent_orchestrator".to_string()
// } else {
// self.send_server_error(
// ServerError::BadRequest {
// why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
// },
// Some(StatusCode::BAD_REQUEST),
// );
// return Action::Continue;
// }
// }
// }
// }
info!(
"on_http_request_body: provider: {}, model requested: {}, model selected: {}",
"on_http_request_body: provider: {}, model requested: {}, model selected: {:?}",
self.llm_provider().name,
model_requested,
model_name.unwrap_or(&"None".to_string()),
self.llm_provider().model,
);
deserialized_body.model = self.llm_provider().model.clone().unwrap();
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
debug!(
@ -404,7 +421,12 @@ impl HttpContext for StreamContext {
self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
Action::Continue
self.chat_completion_request = Some(deserialized_body);
self.request_body = Some(chat_completion_request_str);
self.request_size = Some(body_size);
return Action::Continue;
// return self.route();
}
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
@ -665,4 +687,50 @@ fn current_time_ns() -> u128 {
.as_nanos()
}
impl Context for StreamContext {}
impl Context for StreamContext {
    /// Handles the response of an HTTP callout identified by `token_id`.
    ///
    /// Logs the callback arguments, removes the callout entry for `token_id`
    /// from `self.callouts`, logs the callout response body (lossily decoded
    /// as UTF-8), and finally writes the stored `request_body` back as the
    /// HTTP request body using the stored `request_size`.
    ///
    /// # Panics
    ///
    /// Panics if `self.callouts` has no entry for `token_id`, or if either
    /// `request_size` or `request_body` is `None`.
    fn on_http_call_response(
        &mut self,
        token_id: u32,
        _num_headers: usize,
        body_size: usize,
        _num_trailers: usize,
    ) {
        debug!(
            "on_http_call_response [S={}] token_id={} num_headers={} body_size={} num_trailers={}",
            self.context_id, token_id, _num_headers, body_size, _num_trailers
        );

        // Drop the bookkeeping entry for this callout; the stored context is not used further.
        let removed_entry = self.callouts.borrow_mut().remove(&token_id);
        let _call_context = removed_entry.expect("invalid token_id");

        // A missing response body is treated as empty.
        let response_bytes = self
            .get_http_call_response_body(0, body_size)
            .unwrap_or_default();
        let response_text = String::from_utf8_lossy(&response_bytes);
        info!("on_http_call_response: response body: {}", response_text);

        // Re-apply the request body that was saved on the stream context.
        let saved_size = self.request_size.unwrap();
        let saved_body = self.request_body.as_ref().unwrap();
        self.set_http_request_body(0, saved_size, saved_body.as_bytes());
    }
}
/// `Client` implementation for `StreamContext`, exposing the per-stream callout
/// table and the HTTP-call gauge that the trait requires.
impl Client for StreamContext {
/// Context value stored for each callout.
type CallContext = CallContext;
/// Returns the callout table of this stream, keyed by `u32` token id.
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
&self.callouts
}
/// Returns the `active_http_calls` gauge from this stream's metrics.
fn active_http_calls(&self) -> &Gauge {
&self.metrics.active_http_calls
}
}