mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
initial commit
This commit is contained in:
parent
1f95fac4af
commit
1d19f0c2f7
36 changed files with 3003 additions and 109 deletions
|
|
@ -3,6 +3,8 @@ use proxy_wasm::traits::*;
|
|||
use proxy_wasm::types::*;
|
||||
|
||||
mod filter_context;
|
||||
mod llm_routing;
|
||||
mod llm_routing_consts;
|
||||
mod metrics;
|
||||
mod stream_context;
|
||||
|
||||
|
|
|
|||
106
crates/llm_gateway/src/llm_routing.rs
Normal file
106
crates/llm_gateway/src/llm_routing.rs
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
// use std::rc::Rc;
|
||||
// use std::time::Duration;
|
||||
|
||||
// use common::api::open_ai::{ChatCompletionsRequest, Message};
|
||||
// use common::configuration::LlmProvider;
|
||||
// use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, ARCH_UPSTREAM_HOST_HEADER};
|
||||
// use common::errors::ServerError;
|
||||
// use common::http::{CallArgs, Client};
|
||||
// use log::{info, warn};
|
||||
// use proxy_wasm::traits::HttpContext;
|
||||
// use proxy_wasm::types::Action;
|
||||
|
||||
// use crate::llm_routing_consts::SYSTEM_PROMPT;
|
||||
// use crate::stream_context::{CallContext, StreamContext};
|
||||
|
||||
// pub trait Routing {
|
||||
// fn route(&self) -> Action;
|
||||
// }
|
||||
|
||||
// impl Routing for StreamContext {
|
||||
// fn route(&self) -> Action {
|
||||
// let usage_based_providers = self
|
||||
// .llm_providers
|
||||
// .iter()
|
||||
// .filter(|(_, provider)| provider.usage.is_some())
|
||||
// .map(|(_, provider)| provider.clone())
|
||||
// .collect::<Vec<Rc<LlmProvider>>>();
|
||||
|
||||
// info!(
|
||||
// "usage based providers found: {}",
|
||||
// usage_based_providers
|
||||
// .iter()
|
||||
// .map(|provider| provider.name.clone())
|
||||
// .collect::<Vec<String>>()
|
||||
// .join(", ")
|
||||
// );
|
||||
|
||||
// if usage_based_providers.is_empty() {
|
||||
// self.set_http_request_body(
|
||||
// 0,
|
||||
// self.request_size.unwrap(),
|
||||
// self.request_body.as_ref().unwrap().as_bytes(),
|
||||
// );
|
||||
// return Action::Continue;
|
||||
// }
|
||||
|
||||
// let llm_routes_str = r#"- name: gpt-4o
|
||||
// description: simple requests, basic fact retrieval, easy to answer
|
||||
// - name: o4-mini()
|
||||
// description: complex reasoning problem, require multi step answer"#;
|
||||
|
||||
// let chat_completions_request_messages_str =
|
||||
// serde_json::to_string(&self.chat_completion_request.as_ref().unwrap().messages)
|
||||
// .expect("failed to serialize llm routing request messages");
|
||||
|
||||
// let system_prompt_formatted = SYSTEM_PROMPT
|
||||
// .replace("{routes}", llm_routes_str)
|
||||
// .replace("{conversation}", &chat_completions_request_messages_str);
|
||||
|
||||
// let message = Message {
|
||||
// role: "user".to_string(),
|
||||
// content: Some(system_prompt_formatted),
|
||||
// model: None,
|
||||
// tool_calls: None,
|
||||
// tool_call_id: None,
|
||||
// };
|
||||
|
||||
// let llm_routing_request = ChatCompletionsRequest {
|
||||
// model: "cotran2/llama-1b-4-26".to_string(),
|
||||
// messages: vec![message],
|
||||
// tools: None,
|
||||
// stream: false,
|
||||
// stream_options: None,
|
||||
// metadata: None,
|
||||
// };
|
||||
|
||||
// let llm_routing_request_str = serde_json::to_string(&llm_routing_request)
|
||||
// .expect("failed to serialize llm routing request");
|
||||
|
||||
// let headers = vec![
|
||||
// (":method", "POST"),
|
||||
// (ARCH_UPSTREAM_HOST_HEADER, "gcp_hosted_outer_llm"),
|
||||
// (":path", "/v1/chat/completions"),
|
||||
// (":authority", "gcp_hosted_outer_llm"),
|
||||
// ("content-type", "application/json"),
|
||||
// ("x-envoy-max-retries", "3"),
|
||||
// ("x-envoy-upstream-rq-timeout-ms", "5000"),
|
||||
// ];
|
||||
|
||||
// let call_args = CallArgs::new(
|
||||
// ARCH_INTERNAL_CLUSTER_NAME,
|
||||
// "/v1/chat/completions",
|
||||
// headers,
|
||||
// llm_routing_request_str.as_bytes().into(),
|
||||
// vec![],
|
||||
// Duration::from_secs(5),
|
||||
// );
|
||||
|
||||
// if let Err(e) = self.http_call(call_args, CallContext {}) {
|
||||
// warn!("failed to call LLM provider: {}", e);
|
||||
// self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
// }
|
||||
|
||||
// Action::Pause
|
||||
// }
|
||||
// }
|
||||
32
crates/llm_gateway/src/llm_routing_consts.rs
Normal file
32
crates/llm_gateway/src/llm_routing_consts.rs
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
// pub const SYSTEM_PROMPT: &str = r#"
|
||||
// You are an advanced Routing Assistant designed to select the optimal route based on user requests.
|
||||
// Your task is to analyze conversations and match them to the most appropriate predefined route.
|
||||
// Review the available routes config:
|
||||
|
||||
// # ROUTES CONFIG START
|
||||
// {routes}
|
||||
// # ROUTES CONFIG END
|
||||
|
||||
// Examine the following conversation between a user and an assistant:
|
||||
|
||||
// # CONVERSATION START
|
||||
// {conversation}
|
||||
// # CONVERSATION END
|
||||
|
||||
// Your goal is to identify the most appropriate route that matches the user's LATEST intent. Follow these steps:
|
||||
|
||||
// 1. Carefully read and analyze the provided conversation, focusing on the user's latest request and the conversation scenario.
|
||||
// 2. Check if the user's request and scenario matches any of the routes in the routing configuration (focus on the description).
|
||||
// 3. Find the route that best matches.
|
||||
// 4. Use context clues from the entire conversation to determine the best fit.
|
||||
// 5. Return the best match possible. You only response the name of the route that best matches the user's request, use the exact name in the routes config.
|
||||
// 6. If no route relatively close to matches the user's latest intent or user last message is thank you or greeting, return an empty route ''.
|
||||
|
||||
// # OUTPUT FORMAT
|
||||
// Your final output must follow this JSON format:
|
||||
// {
|
||||
// "route": "route_name" # The matched route name, or empty string '' if no match
|
||||
// }
|
||||
|
||||
// Based on your analysis, provide only the JSON object as your final output with no additional text, explanations, or whitespace.
|
||||
// "#;
|
||||
|
|
@ -9,9 +9,10 @@ use common::consts::{
|
|||
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
use common::errors::ServerError;
|
||||
use common::http::Client;
|
||||
use common::llm_providers::LlmProviders;
|
||||
use common::ratelimit::Header;
|
||||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::stats::{Gauge, IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use http::StatusCode;
|
||||
|
|
@ -19,12 +20,16 @@ use log::{debug, info, warn};
|
|||
use proxy_wasm::hostcalls::get_current_time;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use std::collections::VecDeque;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CallContext {}
|
||||
|
||||
pub struct StreamContext {
|
||||
context_id: u32,
|
||||
metrics: Rc<Metrics>,
|
||||
|
|
@ -32,7 +37,7 @@ pub struct StreamContext {
|
|||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
is_chat_completions_request: bool,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
pub(crate) llm_providers: Rc<LlmProviders>,
|
||||
llm_provider: Option<Rc<LlmProvider>>,
|
||||
request_id: Option<String>,
|
||||
start_time: SystemTime,
|
||||
|
|
@ -43,6 +48,10 @@ pub struct StreamContext {
|
|||
user_message: Option<Message>,
|
||||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
pub(crate) request_body: Option<String>,
|
||||
pub(crate) request_size: Option<usize>,
|
||||
pub(crate) chat_completion_request: Option<ChatCompletionsRequest>,
|
||||
callouts: RefCell<HashMap<u32, CallContext>>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -71,8 +80,13 @@ impl StreamContext {
|
|||
user_message: None,
|
||||
traces_queue,
|
||||
request_body_sent_time: None,
|
||||
request_body: None,
|
||||
request_size: None,
|
||||
chat_completion_request: None,
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn llm_provider(&self) -> &LlmProvider {
|
||||
self.llm_provider
|
||||
.as_ref()
|
||||
|
|
@ -156,7 +170,7 @@ impl StreamContext {
|
|||
});
|
||||
}
|
||||
|
||||
fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
|
||||
pub fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
|
||||
warn!("server error occurred: {}", error);
|
||||
self.send_http_response(
|
||||
override_status_code
|
||||
|
|
@ -228,6 +242,7 @@ impl HttpContext for StreamContext {
|
|||
stream: None,
|
||||
port: None,
|
||||
rate_limits: None,
|
||||
usage: None,
|
||||
}));
|
||||
} else {
|
||||
self.select_llm_provider();
|
||||
|
|
@ -321,7 +336,7 @@ impl HttpContext for StreamContext {
|
|||
// deserialized_body.metadata = None;
|
||||
// delete model key from message array
|
||||
for message in deserialized_body.messages.iter_mut() {
|
||||
message.model = None;
|
||||
// message.model = None;
|
||||
}
|
||||
|
||||
self.user_message = deserialized_body
|
||||
|
|
@ -331,43 +346,45 @@ impl HttpContext for StreamContext {
|
|||
.last()
|
||||
.cloned();
|
||||
|
||||
let model_name = match self.llm_provider.as_ref() {
|
||||
Some(llm_provider) => llm_provider.model.as_ref(),
|
||||
None => None,
|
||||
};
|
||||
// let model_name = match self.llm_provider.as_ref() {
|
||||
// Some(llm_provider) => llm_provider.model.as_ref(),
|
||||
// None => None,
|
||||
// };
|
||||
|
||||
let use_agent_orchestrator = match self.overrides.as_ref() {
|
||||
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
|
||||
None => false,
|
||||
};
|
||||
// let use_agent_orchestrator = match self.overrides.as_ref() {
|
||||
// Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
|
||||
// None => false,
|
||||
// };
|
||||
|
||||
let model_requested = deserialized_body.model.clone();
|
||||
if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
|
||||
deserialized_body.model = match model_name {
|
||||
Some(model_name) => model_name.clone(),
|
||||
None => {
|
||||
if use_agent_orchestrator {
|
||||
"agent_orchestrator".to_string()
|
||||
} else {
|
||||
self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
|
||||
// deserialized_body.model = match model_name {
|
||||
// Some(model_name) => model_name.clone(),
|
||||
// None => {
|
||||
// if use_agent_orchestrator {
|
||||
// "agent_orchestrator".to_string()
|
||||
// } else {
|
||||
// self.send_server_error(
|
||||
// ServerError::BadRequest {
|
||||
// why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
|
||||
// },
|
||||
// Some(StatusCode::BAD_REQUEST),
|
||||
// );
|
||||
// return Action::Continue;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
info!(
|
||||
"on_http_request_body: provider: {}, model requested: {}, model selected: {}",
|
||||
"on_http_request_body: provider: {}, model requested: {}, model selected: {:?}",
|
||||
self.llm_provider().name,
|
||||
model_requested,
|
||||
model_name.unwrap_or(&"None".to_string()),
|
||||
self.llm_provider().model,
|
||||
);
|
||||
|
||||
deserialized_body.model = self.llm_provider().model.clone().unwrap();
|
||||
|
||||
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
|
||||
|
||||
debug!(
|
||||
|
|
@ -404,7 +421,12 @@ impl HttpContext for StreamContext {
|
|||
|
||||
self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
|
||||
|
||||
Action::Continue
|
||||
self.chat_completion_request = Some(deserialized_body);
|
||||
self.request_body = Some(chat_completion_request_str);
|
||||
self.request_size = Some(body_size);
|
||||
|
||||
return Action::Continue;
|
||||
// return self.route();
|
||||
}
|
||||
|
||||
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
|
|
@ -665,4 +687,50 @@ fn current_time_ns() -> u128 {
|
|||
.as_nanos()
|
||||
}
|
||||
|
||||
impl Context for StreamContext {}
|
||||
impl Context for StreamContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
debug!(
|
||||
"on_http_call_response [S={}] token_id={} num_headers={} body_size={} num_trailers={}",
|
||||
self.context_id, token_id, _num_headers, body_size, _num_trailers
|
||||
);
|
||||
|
||||
let _callout_data = self
|
||||
.callouts
|
||||
.borrow_mut()
|
||||
.remove(&token_id)
|
||||
.expect("invalid token_id");
|
||||
|
||||
let body = self
|
||||
.get_http_call_response_body(0, body_size)
|
||||
.unwrap_or_default();
|
||||
|
||||
info!(
|
||||
"on_http_call_response: response body: {}",
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
|
||||
self.set_http_request_body(
|
||||
0,
|
||||
self.request_size.unwrap(),
|
||||
self.request_body.as_ref().unwrap().as_bytes(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl Client for StreamContext {
|
||||
type CallContext = CallContext;
|
||||
|
||||
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
|
||||
&self.callouts
|
||||
}
|
||||
|
||||
fn active_http_calls(&self) -> &Gauge {
|
||||
&self.metrics.active_http_calls
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue