This commit is contained in:
Adil Hafeez 2025-03-24 16:53:36 -07:00
parent 9f59943041
commit 21a5bdd561
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
8 changed files with 239 additions and 47 deletions

View file

@ -56,7 +56,7 @@ def docker_start_archgw_detached(
volume_mappings = [
f"{logs_path_abs}:/var/log:rw",
f"{arch_config_file}:/app/arch_config.yaml:ro",
# "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
"/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
]
volume_mappings_args = [
item for volume in volume_mappings for item in ("-v", volume)

56
crates/Cargo.lock generated
View file

@ -64,9 +64,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.90"
version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37bf3594c4c988a53154954629820791dde498571819ae4ca50ca811e060cc95"
checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f"
[[package]]
name = "arbitrary"
@ -82,7 +82,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.87",
]
[[package]]
@ -487,7 +487,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.87",
]
[[package]]
@ -919,7 +919,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.87",
]
[[package]]
@ -1061,6 +1061,7 @@ dependencies = [
"http",
"log",
"md5",
"omnillm",
"proxy-wasm",
"proxy-wasm-test-framework",
"rand",
@ -1160,6 +1161,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "omnillm"
version = "0.1.0"
dependencies = [
"anyhow",
"serde",
"serde_json",
]
[[package]]
name = "once_cell"
version = "1.20.2"
@ -1529,29 +1539,29 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.210"
version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a"
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.210"
version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.87",
]
[[package]]
name = "serde_json"
version = "1.0.130"
version = "1.0.140"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "610f75ff4a8e3cb29b85da56eabdd1bff5b06739059a4b8e2967fef32e5d9944"
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
dependencies = [
"itoa",
"memchr",
@ -1603,7 +1613,7 @@ checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.87",
]
[[package]]
@ -1711,9 +1721,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.79"
version = "2.0.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590"
checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d"
dependencies = [
"proc-macro2",
"quote",
@ -1728,7 +1738,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.87",
]
[[package]]
@ -1772,7 +1782,7 @@ checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.87",
]
[[package]]
@ -2071,7 +2081,7 @@ dependencies = [
"anyhow",
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.87",
"wasmtime-component-util",
"wasmtime-wit-bindgen",
"wit-parser",
@ -2201,7 +2211,7 @@ checksum = "a2bde986038b819bc43a21fef0610aeb47aabfe3ea09ca3533a7b81023b84ec6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.87",
]
[[package]]
@ -2450,7 +2460,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.87",
"synstructure",
]
@ -2472,7 +2482,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.87",
]
[[package]]
@ -2492,7 +2502,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.87",
"synstructure",
]
@ -2515,7 +2525,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.87",
]
[[package]]

View file

@ -1,3 +1,3 @@
[workspace]
resolver = "2"
members = ["llm_gateway", "prompt_gateway", "common"]
members = ["llm_gateway", "prompt_gateway", "common", "omnillm"]

View file

@ -51,7 +51,7 @@ impl Serialize for FunctionParameters {
where
S: serde::Serializer,
{
// select all requried parameters
// select all required parameters
let required: Vec<&String> = self
.properties
.iter()

View file

@ -22,6 +22,7 @@ rand = "0.8.5"
thiserror = "1.0.64"
derivative = "2.2.0"
sha2 = "0.10.8"
omnillm = { version = "0.1.0", path = "../omnillm" }
[dev-dependencies]
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }

View file

@ -16,12 +16,14 @@ use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use http::StatusCode;
use log::{debug, trace, warn};
use omnillm::{ChatRequest, LlmProvider as LlmProviderTrait, LlmProviders as OmniLlmProviders};
use proxy_wasm::hostcalls::get_current_time;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::collections::VecDeque;
use std::num::NonZero;
use std::rc::Rc;
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
@ -40,9 +42,10 @@ pub struct StreamContext {
ttft_time: Option<u128>,
traceparent: Option<String>,
request_body_sent_time: Option<u128>,
user_message: Option<Message>,
user_message: Option<omnillm::Message>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
overrides: Rc<Option<Overrides>>,
omni_llm_providers: Rc<OmniLlmProviders>,
}
impl StreamContext {
@ -71,6 +74,7 @@ impl StreamContext {
user_message: None,
traces_queue,
request_body_sent_time: None,
omni_llm_providers: Rc::new(OmniLlmProviders::new()),
}
}
fn llm_provider(&self) -> &LlmProvider {
@ -271,18 +275,17 @@ impl HttpContext for StreamContext {
// Deserialize body into spec.
// Currently OpenAI API.
let mut deserialized_body: ChatCompletionsRequest =
match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
debug!("body str: {}", String::from_utf8_lossy(&body_bytes));
self.send_server_error(
ServerError::Deserialization(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
};
let mut deserialized_body: ChatRequest = match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
debug!("body str: {}", String::from_utf8_lossy(&body_bytes));
self.send_server_error(
ServerError::Deserialization(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
};
// remove metadata from the request body
//TODO: move this to prompt gateway
@ -295,7 +298,7 @@ impl HttpContext for StreamContext {
self.user_message = deserialized_body
.messages
.iter()
.filter(|m| m.role == "user")
.filter(|m| m.role == omnillm::Role::User)
.last()
.cloned();
@ -336,16 +339,32 @@ impl HttpContext for StreamContext {
model_name,
);
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
let omn_provider_type =
omnillm::Provider::from_str(&self.llm_provider().provider_interface.to_string())
.unwrap();
trace!("request body: {}", chat_completion_request_str);
let chat_request_bytes = self
.omni_llm_providers
.as_ref()
.providers
.get(&omn_provider_type)
.unwrap()
.translate_request(&deserialized_body)
.unwrap();
if deserialized_body.stream {
trace!(
"request body str: {}",
String::from_utf8_lossy(&chat_request_bytes)
);
if deserialized_body.stream.unwrap_or_default() {
self.streaming_response = true;
}
if deserialized_body.stream && deserialized_body.stream_options.is_none() {
deserialized_body.stream_options = Some(StreamOptions {
include_usage: true,
if deserialized_body.stream.unwrap_or_default()
&& deserialized_body.stream_options.is_none()
{
deserialized_body.stream_options = Some(omnillm::StreamOptions {
include_usage: Some(true),
});
}
@ -367,7 +386,7 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
self.set_http_request_body(0, body_size, &chat_request_bytes);
Action::Continue
}

View file

@ -0,0 +1,9 @@
[package]
name = "omnillm"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = "1.0.97"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"

153
crates/omnillm/src/lib.rs Normal file
View file

@ -0,0 +1,153 @@
use std::{collections::HashMap, str::FromStr};
use anyhow::Result;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Hash, PartialEq, Eq)]
pub enum Provider {
OpenAI,
Mistral,
}
impl FromStr for Provider {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self> {
match s.to_lowercase().as_str() {
"openai" => Ok(Provider::OpenAI),
"mistral" => Ok(Provider::Mistral),
_ => Err(anyhow::anyhow!(format!("Invalid provider: {}", s))),
}
}
}
#[derive(Serialize, Deserialize)]
pub struct ChatRequest {
pub model: String,
pub messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<i64>,
}
#[derive(Serialize, Deserialize)]
pub struct StreamOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub include_usage: Option<bool>,
}
#[derive(Serialize, Deserialize, PartialEq, Clone)]
pub enum Role {
#[serde(rename = "system")]
System,
#[serde(rename = "user")]
User,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct Message {
pub role: Role,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
}
#[derive(Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: i64,
pub completion_tokens: i64,
pub total: i64,
}
#[derive(Serialize, Deserialize)]
pub struct ChatResponse {
pub messages: Vec<Message>,
pub usage: Option<Usage>,
}
pub trait LlmProvider {
fn translate_request(&self, request: &ChatRequest) -> Result<Vec<u8>>;
fn translate_response(&self, response: &Vec<u8>) -> Result<ChatResponse>;
}
pub struct LlmProviders {
pub providers: HashMap<Provider, Box<dyn LlmProvider>>,
}
impl LlmProviders {
pub fn new() -> LlmProviders {
LlmProviders {
providers: HashMap::from([
(Provider::OpenAI, Box::new(OpenAI) as Box<dyn LlmProvider>),
(Provider::Mistral, Box::new(Mistral) as Box<dyn LlmProvider>),
]),
}
}
}
pub struct OpenAI;
impl LlmProvider for OpenAI {
fn translate_request(&self, request: &ChatRequest) -> Result<Vec<u8>> {
serde_json::to_string(request)
.map(|s| s.into_bytes())
.map_err(Into::into)
}
fn translate_response(&self, response: &Vec<u8>) -> Result<ChatResponse> {
serde_json::from_slice(response).map_err(Into::into)
}
}
pub struct Mistral;
impl LlmProvider for Mistral {
fn translate_request(&self, request: &ChatRequest) -> Result<Vec<u8>> {
serde_json::to_string(request)
.map(|s| s.into_bytes())
.map_err(Into::into)
}
fn translate_response(&self, response: &Vec<u8>) -> Result<ChatResponse> {
serde_json::from_slice(response).map_err(Into::into)
}
}
#[cfg(test)]
mod test {
#[test]
fn test_translate_request_openai() {
use super::{ChatRequest, LlmProvider, Message, OpenAI, Role};
let request = ChatRequest {
model: "gpt-3.5-turbo".to_string(),
messages: vec![
Message {
role: Role::System,
content: Some("You are a helpful assistant.".to_string()),
model: None,
},
Message {
role: Role::User,
content: Some("I need help with my computer.".to_string()),
model: None,
},
],
temperature: None,
max_tokens: None,
stream: None,
stream_options: None,
};
let openai = OpenAI;
let result = openai.translate_request(&request).unwrap();
let expected = r#"{"model":"gpt-3.5-turbo","messages":[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"I need help with my computer."}]}"#;
assert_eq!(String::from_utf8(result).unwrap(), expected);
}
}